From ee87dafab396a36bc718b47f450ef2477e3596dc Mon Sep 17 00:00:00 2001 From: tdophung Date: Thu, 5 Mar 2026 17:27:16 -0800 Subject: [PATCH 1/7] add guard at bisected jax version where lower is segfault Signed-off-by: tdophung --- transformer_engine/jax/triton_extensions/utils.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/transformer_engine/jax/triton_extensions/utils.py b/transformer_engine/jax/triton_extensions/utils.py index 2627a08929..ff23a263e9 100644 --- a/transformer_engine/jax/triton_extensions/utils.py +++ b/transformer_engine/jax/triton_extensions/utils.py @@ -162,6 +162,21 @@ def _check_triton_compatibility(): "If you don't need Triton, use transformer_engine.jax.cpp_extensions instead." ) from e +# Minimum jaxlib version required for Triton kernel dispatch to work correctly. +# jaxlib < 0.8.0.dev20250924 segfaults in pxla.py during Triton kernel execution +# (bisected: last known segfault = jax-2025-09-23, first known pass = jax-2025-09-24). +_JAXLIB_MIN_VERSION = "0.8.0.dev20250924" +import jaxlib # noqa: E402 + +if version.parse(jaxlib.__version__) < version.parse(_JAXLIB_MIN_VERSION): + raise RuntimeError( + f"jaxlib {jaxlib.__version__} is too old for transformer_engine.jax.triton_extensions.\n" + f"Triton kernel dispatch segfaults with jaxlib < {_JAXLIB_MIN_VERSION}.\n" + f"Please upgrade: pip install --upgrade jax jaxlib\n" + f"Or use a JAX nightly container dated 2025-09-24 or later.\n" + f"If you don't need Triton, use transformer_engine.jax.cpp_extensions instead." + ) + __all__ = ["triton_call_lowering", "get_triton_info"] From 14666c021fecfa5c89baccbba46ea25c9f99e46c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 6 Mar 2026 01:35:38 +0000 Subject: [PATCH 2/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/jax/triton_extensions/utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/transformer_engine/jax/triton_extensions/utils.py b/transformer_engine/jax/triton_extensions/utils.py index ff23a263e9..18aa09c6cd 100644 --- a/transformer_engine/jax/triton_extensions/utils.py +++ b/transformer_engine/jax/triton_extensions/utils.py @@ -172,9 +172,9 @@ def _check_triton_compatibility(): raise RuntimeError( f"jaxlib {jaxlib.__version__} is too old for transformer_engine.jax.triton_extensions.\n" f"Triton kernel dispatch segfaults with jaxlib < {_JAXLIB_MIN_VERSION}.\n" - f"Please upgrade: pip install --upgrade jax jaxlib\n" - f"Or use a JAX nightly container dated 2025-09-24 or later.\n" - f"If you don't need Triton, use transformer_engine.jax.cpp_extensions instead." + "Please upgrade: pip install --upgrade jax jaxlib\n" + "Or use a JAX nightly container dated 2025-09-24 or later.\n" + "If you don't need Triton, use transformer_engine.jax.cpp_extensions instead." ) From 4a3f764b7ed8fcf87d458ead42c3c63f80e53d7d Mon Sep 17 00:00:00 2001 From: tdophung Date: Mon, 9 Mar 2026 15:00:46 -0700 Subject: [PATCH 3/7] Fix pylint: remove unused lru_cache import and fix import order in helper.py Signed-off-by: tdophung --- transformer_engine/jax/quantize/helper.py | 15 ++------------- 1 file changed, 2 insertions(+), 13 deletions(-) diff --git a/transformer_engine/jax/quantize/helper.py b/transformer_engine/jax/quantize/helper.py index 7d81d71bc8..2548ec9e46 100644 --- a/transformer_engine/jax/quantize/helper.py +++ b/transformer_engine/jax/quantize/helper.py @@ -14,11 +14,9 @@ from enum import Enum import hashlib from typing import Optional, Tuple, Dict, Union, Sequence, Type, List -from functools import reduce, lru_cache +from functools import reduce import operator -from importlib.metadata import version as get_pkg_version import warnings -from packaging.version import Version as PkgVersion import jax import jax.numpy as jnp @@ -40,6 +38,7 @@ get_all_mesh_axes, with_sharding_constraint, ) +from transformer_engine.jax.version_utils import _jax_version_meet_requirement from .metadata import QuantizeMeta from .scaling_modes import ScalingMode @@ -68,16 +67,6 @@ NVTE_FP8_COLLECTION_NAME = "fp8_metas" -@lru_cache(maxsize=None) -def _jax_version_meet_requirement(version: str): - """ - Helper function checking if required JAX version is available - """ - jax_version = PkgVersion(get_pkg_version("jax")) - jax_version_required = PkgVersion(version) - return jax_version >= jax_version_required - - def _check_delayed_scaling_fp8_support(gpu_arch) -> Tuple[bool, str]: """Check if delayed scaling FP8 is supported on the given GPU architecture. From 5be75b20254f30369d3f775d7b9e513390879e6b Mon Sep 17 00:00:00 2001 From: tdophung Date: Mon, 9 Mar 2026 16:16:01 -0700 Subject: [PATCH 4/7] Guard Triton tests against JAX < 0.8.0 using release version check - Add version_utils.py with is_triton_extension_supported() checking JAX >= 0.8.0 (release version, not dev snapshot) and TRITON_EXTENSION_MIN_JAX_VERSION constant - Add pytest.mark.triton marker and conftest hook to skip marked tests on old JAX - Add require_triton() for module-level skipping in test files - Rewrite triton_extensions to use is_triton_extension_supported() instead of direct jaxlib dev-version comparison Signed-off-by: tdophung --- qa/L0_jax_lint/test.sh | 2 +- tests/jax/conftest.py | 25 +++++++++++ tests/jax/test_distributed_permutation.py | 22 +++++++--- tests/jax/test_distributed_router.py | 24 +++++++++-- tests/jax/test_fused_router.py | 24 +++++++++-- tests/jax/test_permutation.py | 26 ++++++++--- tests/jax/test_triton_custom_calls.py | 5 ++- tests/jax/utils.py | 14 ++++++ .../jax/triton_extensions/__init__.py | 3 ++ .../jax/triton_extensions/utils.py | 25 +++++------ transformer_engine/jax/version_utils.py | 43 +++++++++++++++++++ 11 files changed, 180 insertions(+), 33 deletions(-) mode change 100644 => 100755 qa/L0_jax_lint/test.sh create mode 100644 transformer_engine/jax/version_utils.py diff --git a/qa/L0_jax_lint/test.sh b/qa/L0_jax_lint/test.sh old mode 100644 new mode 100755 index 3f804d3ef9..9014ea90e9 --- a/qa/L0_jax_lint/test.sh +++ b/qa/L0_jax_lint/test.sh @@ -4,7 +4,7 @@ set -e -: "${TE_PATH:=/opt/transformerengine}" +: "${TE_PATH:=/Users/tdophung/Repos/TransformerEngine}" pip3 install cpplint==1.6.0 pylint==3.3.1 if [ -z "${PYTHON_ONLY}" ] diff --git a/tests/jax/conftest.py b/tests/jax/conftest.py index 6b7520d147..0e8b62cd22 100644 --- a/tests/jax/conftest.py +++ b/tests/jax/conftest.py @@ -11,6 +11,10 @@ import transformer_engine.jax from transformer_engine_jax import get_device_compute_capability +from transformer_engine.jax.version_utils import ( + TRITON_EXTENSION_MIN_JAX_VERSION, + is_triton_extension_supported, +) @pytest.fixture(autouse=True, scope="function") @@ -83,5 +87,26 @@ def pytest_sessionfinish(self, session, exitstatus): def pytest_configure(config): + config.addinivalue_line( + "markers", + "triton: mark test (or test class) as requiring JAX Triton kernel support (JAX >= 0.8.0). " + "Apply per test/class with @pytest.mark.triton so non-Triton tests in the same file run on old JAX.", + ) if os.getenv("NVTE_JAX_TEST_TIMING", "0") == "1": config.pluginmanager.register(TestTimingPlugin(), "test_timing") + + +def pytest_collection_modifyitems(config, items): + """Skip tests marked 'triton' when JAX is too old for Triton kernel dispatch.""" + if is_triton_extension_supported(): + return + skip_triton = pytest.mark.skip( + reason=( + f"JAX >= {TRITON_EXTENSION_MIN_JAX_VERSION} required for Triton kernel support. " + "Triton kernel dispatch segfaults with older jaxlib. " + "Upgrade with: pip install --upgrade jax jaxlib" + ) + ) + for item in items: + if item.get_closest_marker("triton"): + item.add_marker(skip_triton) diff --git a/tests/jax/test_distributed_permutation.py b/tests/jax/test_distributed_permutation.py index 04ed236e81..53c21da9cf 100644 --- a/tests/jax/test_distributed_permutation.py +++ b/tests/jax/test_distributed_permutation.py @@ -34,11 +34,22 @@ from distributed_test_base import generate_configs from utils import assert_allclose, pytest_parametrize_wrapper -# High-level API with VJP support -from transformer_engine.jax.permutation import ( - token_dispatch, - token_combine, -) + +@pytest.fixture(autouse=True, scope="function") +def _inject_permutation(request): + """Lazy-load permutation API only for tests marked 'triton'. Other tests run without importing.""" + if not request.node.get_closest_marker("triton"): + yield + return + import sys + from transformer_engine.jax.permutation import token_dispatch, token_combine + mod = sys.modules[__name__] + mod.token_dispatch = token_dispatch + mod.token_combine = token_combine + yield + + +# High-level API with VJP support (injected by _inject_permutation) # Reference implementations from test_permutation.py from test_permutation import ( @@ -80,6 +91,7 @@ } +@pytest.mark.triton class TestDistributedPermutation: """Test distributed/sharded execution of MoE permutation primitives. diff --git a/tests/jax/test_distributed_router.py b/tests/jax/test_distributed_router.py index 1b3fe14e75..c45efd9c56 100644 --- a/tests/jax/test_distributed_router.py +++ b/tests/jax/test_distributed_router.py @@ -34,10 +34,23 @@ from distributed_test_base import generate_configs from utils import assert_allclose, pytest_parametrize_wrapper -from transformer_engine.jax.router import ( - fused_topk_with_score_function, - fused_moe_aux_loss, -) + +@pytest.fixture(autouse=True, scope="function") +def _inject_router(request): + """Lazy-load router API only for tests marked 'triton'. Other tests run without importing.""" + if not request.node.get_closest_marker("triton"): + yield + return + import sys + from transformer_engine.jax.router import ( + fused_topk_with_score_function, + fused_moe_aux_loss, + ) + mod = sys.modules[__name__] + mod.fused_topk_with_score_function = fused_topk_with_score_function + mod.fused_moe_aux_loss = fused_moe_aux_loss + yield + jax.config.update("jax_use_shardy_partitioner", True) @@ -68,6 +81,7 @@ } +@pytest.mark.triton class TestDistributedFusedTopk: """Test distributed execution of fused_topk_with_score_function. @@ -200,6 +214,7 @@ def test_distributed_topk( ) +@pytest.mark.triton class TestDistributedScoreForAuxLoss: """Test distributed execution of fused_topk_with_score_function with compute_aux_scores=True. @@ -333,6 +348,7 @@ def test_distributed_score_for_aux_loss( ) +@pytest.mark.triton class TestDistributedMoEAuxLoss: """Test distributed execution of fused_moe_aux_loss. diff --git a/tests/jax/test_fused_router.py b/tests/jax/test_fused_router.py index 77e89457c8..a9d0fdcc40 100644 --- a/tests/jax/test_fused_router.py +++ b/tests/jax/test_fused_router.py @@ -4,6 +4,7 @@ """Tests for fused MoE router CUDA kernels (JAX wrappers).""" +import sys from functools import partial from typing import Optional @@ -13,10 +14,21 @@ from utils import pytest_parametrize_wrapper -from transformer_engine.jax.router import ( - fused_topk_with_score_function, - fused_moe_aux_loss, -) + +@pytest.fixture(autouse=True, scope="function") +def _inject_router(request): + """Lazy-load router API only for tests marked 'triton'. Other tests run without importing.""" + if not request.node.get_closest_marker("triton"): + yield + return + from transformer_engine.jax.router import ( + fused_topk_with_score_function, + fused_moe_aux_loss, + ) + mod = sys.modules[__name__] + mod.fused_topk_with_score_function = fused_topk_with_score_function + mod.fused_moe_aux_loss = fused_moe_aux_loss + yield # ============================================================================= # Test case definitions (L0 = fast smoke, L2 = comprehensive) @@ -371,6 +383,7 @@ def loss_fused(logits_): @pytest_parametrize_wrapper("group_topk", GROUP_TOPK_OPTIONS) @pytest_parametrize_wrapper("scaling_factor", SCALING_FACTOR_OPTIONS) @pytest_parametrize_wrapper("enable_bias", ENABLE_BIAS_OPTIONS) +@pytest.mark.triton def test_topk_sigmoid( dtype, num_tokens, num_experts, topk, group_topk, scaling_factor, enable_bias ): @@ -397,6 +410,7 @@ def test_topk_sigmoid( @pytest_parametrize_wrapper("use_pre_softmax", USE_PRE_SOFTMAX_OPTIONS) @pytest_parametrize_wrapper("group_topk", GROUP_TOPK_OPTIONS) @pytest_parametrize_wrapper("scaling_factor", SCALING_FACTOR_OPTIONS) +@pytest.mark.triton def test_topk_softmax( dtype, num_tokens, num_experts, topk, use_pre_softmax, group_topk, scaling_factor ): @@ -426,6 +440,7 @@ def test_topk_softmax( SCORE_AUX_LOSS_CASES, ) @pytest_parametrize_wrapper("score_function", SCORE_FUNCTIONS) +@pytest.mark.triton def test_fused_scores_for_aux_loss(dtype, num_tokens, num_experts, topk, score_function): logits = make_logits(num_tokens, num_experts, score_function, dtype) @@ -486,6 +501,7 @@ def loss_fused(logits_): "num_tokens,num_experts,topk", AUX_LOSS_CASES, ) +@pytest.mark.triton def test_fused_moe_aux_loss(dtype, num_tokens, num_experts, topk): key = jax.random.PRNGKey(SEED) diff --git a/tests/jax/test_permutation.py b/tests/jax/test_permutation.py index 138a817240..f1e656f619 100644 --- a/tests/jax/test_permutation.py +++ b/tests/jax/test_permutation.py @@ -5,20 +5,33 @@ """Tests for permutation Triton kernels and high-level APIs""" import functools +import sys import jax import jax.numpy as jnp import pytest -# High-level API with VJP support -from transformer_engine.jax.permutation import ( - token_dispatch, - token_combine, - sort_chunks_by_index, -) from utils import assert_allclose, pytest_parametrize_wrapper +@pytest.fixture(autouse=True, scope="function") +def _inject_permutation(request): + """Lazy-load permutation API only for tests marked 'triton'. Other tests run without importing.""" + if not request.node.get_closest_marker("triton"): + yield + return + from transformer_engine.jax.permutation import ( + token_dispatch, + token_combine, + sort_chunks_by_index, + ) + mod = sys.modules[__name__] + mod.token_dispatch = token_dispatch + mod.token_combine = token_combine + mod.sort_chunks_by_index = sort_chunks_by_index + yield + + ALL_DISPATCH_COMBINE_CASES = [ (128, 5, 128, 3), (1024, 8, 128, 8), @@ -449,6 +462,7 @@ def reference_sort_chunks_by_map( return output, permuted_probs +@pytest.mark.triton class TestHighLevelPermutationAPI: """Test high-level permutation APIs (token_dispatch, token_combine, etc.) diff --git a/tests/jax/test_triton_custom_calls.py b/tests/jax/test_triton_custom_calls.py index 6d969de0d3..139c6b91a6 100644 --- a/tests/jax/test_triton_custom_calls.py +++ b/tests/jax/test_triton_custom_calls.py @@ -7,7 +7,9 @@ import jax.numpy as jnp import pytest -from utils import assert_allclose, pytest_parametrize_wrapper +from utils import assert_allclose, pytest_parametrize_wrapper, require_triton + +require_triton() import triton import triton.language as tl @@ -23,6 +25,7 @@ def init(): yield +@pytest.mark.triton class TestTritonBinding: """Test Triton binding primitive.""" diff --git a/tests/jax/utils.py b/tests/jax/utils.py index c22b0a6063..d412796cba 100644 --- a/tests/jax/utils.py +++ b/tests/jax/utils.py @@ -26,6 +26,10 @@ make_swa_mask, ) from transformer_engine.jax.quantize.helper import DType as TEDType +from transformer_engine.jax.version_utils import ( + TRITON_EXTENSION_MIN_JAX_VERSION, + is_triton_extension_supported, +) PRNGKey = Any Shape = Tuple[int, ...] @@ -40,6 +44,16 @@ NVTE_DEBUG_NUMERICS = bool(int(os.getenv("NVTE_DEBUG_NUMERICS", 0))) +def require_triton(): + """Skip the current test module if JAX is too old for Triton kernel support.""" + if not is_triton_extension_supported(): + pytest.skip( + f"JAX >= {TRITON_EXTENSION_MIN_JAX_VERSION} required for Triton kernel support. " + "Triton kernel dispatch segfaults with older jaxlib. " + "Upgrade with: pip install --upgrade jax jaxlib" + ) + + def is_devices_enough(required): """ Check if the available GPUs is enough diff --git a/transformer_engine/jax/triton_extensions/__init__.py b/transformer_engine/jax/triton_extensions/__init__.py index d9708fde9f..150a5fbf12 100644 --- a/transformer_engine/jax/triton_extensions/__init__.py +++ b/transformer_engine/jax/triton_extensions/__init__.py @@ -54,6 +54,9 @@ def lowering(ctx, x, **kwargs): from transformer_engine.jax.triton_extensions import get_triton_info info = get_triton_info() print(f"Using Triton {info['version']} from {info['source']}") + + # Check if JAX version supports Triton (without importing triton_extensions) + from transformer_engine.jax.version_utils import is_triton_extension_supported """ from .utils import * diff --git a/transformer_engine/jax/triton_extensions/utils.py b/transformer_engine/jax/triton_extensions/utils.py index 18aa09c6cd..f9b3376c0c 100644 --- a/transformer_engine/jax/triton_extensions/utils.py +++ b/transformer_engine/jax/triton_extensions/utils.py @@ -42,6 +42,11 @@ import jax import jax.numpy as jnp +from ..version_utils import ( + TRITON_EXTENSION_MIN_JAX_VERSION, + is_triton_extension_supported, +) + # Placeholder package version on PyPI that should never be used _PYTORCH_TRITON_PLACEHOLDER_VERSION = "0.0.1" @@ -162,23 +167,19 @@ def _check_triton_compatibility(): "If you don't need Triton, use transformer_engine.jax.cpp_extensions instead." ) from e -# Minimum jaxlib version required for Triton kernel dispatch to work correctly. -# jaxlib < 0.8.0.dev20250924 segfaults in pxla.py during Triton kernel execution -# (bisected: last known segfault = jax-2025-09-23, first known pass = jax-2025-09-24). -_JAXLIB_MIN_VERSION = "0.8.0.dev20250924" -import jaxlib # noqa: E402 - -if version.parse(jaxlib.__version__) < version.parse(_JAXLIB_MIN_VERSION): +# Enforce minimum JAX version for Triton kernel dispatch (segfaults on jaxlib < 0.8.0). +if not is_triton_extension_supported(): raise RuntimeError( - f"jaxlib {jaxlib.__version__} is too old for transformer_engine.jax.triton_extensions.\n" - f"Triton kernel dispatch segfaults with jaxlib < {_JAXLIB_MIN_VERSION}.\n" - "Please upgrade: pip install --upgrade jax jaxlib\n" - "Or use a JAX nightly container dated 2025-09-24 or later.\n" + f"JAX >= {TRITON_EXTENSION_MIN_JAX_VERSION} required for " + "transformer_engine.jax.triton_extensions. " + f"Triton kernel dispatch segfaults with older jaxlib. " + f"Current jax version: {jax.__version__}. " + "Please upgrade: pip install --upgrade jax jaxlib. " "If you don't need Triton, use transformer_engine.jax.cpp_extensions instead." ) -__all__ = ["triton_call_lowering", "get_triton_info"] +__all__ = ["triton_call_lowering", "get_triton_info", "is_triton_extension_supported"] # Triton kernel cache (module-level, shared across all kernels) _TRITON_KERNEL_CACHE = {} diff --git a/transformer_engine/jax/version_utils.py b/transformer_engine/jax/version_utils.py new file mode 100644 index 0000000000..60997b44d2 --- /dev/null +++ b/transformer_engine/jax/version_utils.py @@ -0,0 +1,43 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +""" +JAX version helpers. + +Provides version checks for JAX that can be used across TE JAX (quantize, triton +extensions, etc.) without pulling in feature-specific code. +""" + +from functools import lru_cache +from importlib.metadata import version as get_pkg_version + +from packaging.version import Version as PkgVersion + + +@lru_cache(maxsize=None) +def _jax_version_meet_requirement(version: str): + """Return True if the installed JAX version is >= the required version.""" + jax_version = PkgVersion(get_pkg_version("jax")) + jax_version_required = PkgVersion(version) + return jax_version >= jax_version_required + + +# Minimum JAX version required for Triton kernel dispatch (jaxlib < 0.8.0 segfaults). +TRITON_EXTENSION_MIN_JAX_VERSION = "0.8.0" + + +def is_triton_extension_supported() -> bool: + """Return True if the current JAX version supports Triton kernel dispatch. + + JAX/jaxlib >= 0.8.0 is required. Older versions segfault when dispatching + Triton kernels. Use this to skip tests or gate features without importing + triton_extensions (which would raise immediately on old jax). + """ + return _jax_version_meet_requirement(TRITON_EXTENSION_MIN_JAX_VERSION) + + +__all__ = [ + "_jax_version_meet_requirement", + "is_triton_extension_supported", + "TRITON_EXTENSION_MIN_JAX_VERSION", +] From bd7b58fe7ea374aa773a5c2e9acacbf5547b9b6c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 9 Mar 2026 23:16:59 +0000 Subject: [PATCH 5/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/jax/conftest.py | 5 +++-- tests/jax/test_distributed_permutation.py | 1 + tests/jax/test_distributed_router.py | 1 + tests/jax/test_fused_router.py | 2 ++ tests/jax/test_permutation.py | 1 + transformer_engine/jax/triton_extensions/utils.py | 2 +- 6 files changed, 9 insertions(+), 3 deletions(-) diff --git a/tests/jax/conftest.py b/tests/jax/conftest.py index 0e8b62cd22..22e1622634 100644 --- a/tests/jax/conftest.py +++ b/tests/jax/conftest.py @@ -89,8 +89,9 @@ def pytest_sessionfinish(self, session, exitstatus): def pytest_configure(config): config.addinivalue_line( "markers", - "triton: mark test (or test class) as requiring JAX Triton kernel support (JAX >= 0.8.0). " - "Apply per test/class with @pytest.mark.triton so non-Triton tests in the same file run on old JAX.", + "triton: mark test (or test class) as requiring JAX Triton kernel support (JAX >= 0.8.0)." + " Apply per test/class with @pytest.mark.triton so non-Triton tests in the same file run on" + " old JAX.", ) if os.getenv("NVTE_JAX_TEST_TIMING", "0") == "1": config.pluginmanager.register(TestTimingPlugin(), "test_timing") diff --git a/tests/jax/test_distributed_permutation.py b/tests/jax/test_distributed_permutation.py index 53c21da9cf..8a01a79ce2 100644 --- a/tests/jax/test_distributed_permutation.py +++ b/tests/jax/test_distributed_permutation.py @@ -43,6 +43,7 @@ def _inject_permutation(request): return import sys from transformer_engine.jax.permutation import token_dispatch, token_combine + mod = sys.modules[__name__] mod.token_dispatch = token_dispatch mod.token_combine = token_combine diff --git a/tests/jax/test_distributed_router.py b/tests/jax/test_distributed_router.py index c45efd9c56..cd30402585 100644 --- a/tests/jax/test_distributed_router.py +++ b/tests/jax/test_distributed_router.py @@ -46,6 +46,7 @@ def _inject_router(request): fused_topk_with_score_function, fused_moe_aux_loss, ) + mod = sys.modules[__name__] mod.fused_topk_with_score_function = fused_topk_with_score_function mod.fused_moe_aux_loss = fused_moe_aux_loss diff --git a/tests/jax/test_fused_router.py b/tests/jax/test_fused_router.py index a9d0fdcc40..1428c1f863 100644 --- a/tests/jax/test_fused_router.py +++ b/tests/jax/test_fused_router.py @@ -25,11 +25,13 @@ def _inject_router(request): fused_topk_with_score_function, fused_moe_aux_loss, ) + mod = sys.modules[__name__] mod.fused_topk_with_score_function = fused_topk_with_score_function mod.fused_moe_aux_loss = fused_moe_aux_loss yield + # ============================================================================= # Test case definitions (L0 = fast smoke, L2 = comprehensive) # ============================================================================= diff --git a/tests/jax/test_permutation.py b/tests/jax/test_permutation.py index f1e656f619..74158a93e9 100644 --- a/tests/jax/test_permutation.py +++ b/tests/jax/test_permutation.py @@ -25,6 +25,7 @@ def _inject_permutation(request): token_combine, sort_chunks_by_index, ) + mod = sys.modules[__name__] mod.token_dispatch = token_dispatch mod.token_combine = token_combine diff --git a/transformer_engine/jax/triton_extensions/utils.py b/transformer_engine/jax/triton_extensions/utils.py index f9b3376c0c..4453ab8098 100644 --- a/transformer_engine/jax/triton_extensions/utils.py +++ b/transformer_engine/jax/triton_extensions/utils.py @@ -172,7 +172,7 @@ def _check_triton_compatibility(): raise RuntimeError( f"JAX >= {TRITON_EXTENSION_MIN_JAX_VERSION} required for " "transformer_engine.jax.triton_extensions. " - f"Triton kernel dispatch segfaults with older jaxlib. " + "Triton kernel dispatch segfaults with older jaxlib. " f"Current jax version: {jax.__version__}. " "Please upgrade: pip install --upgrade jax jaxlib. " "If you don't need Triton, use transformer_engine.jax.cpp_extensions instead." From c19e9f5b2dfd01c14cc175b6cb0264f3248b7f65 Mon Sep 17 00:00:00 2001 From: tdophung Date: Mon, 9 Mar 2026 16:47:03 -0700 Subject: [PATCH 6/7] Address review: allow_module_level, drop is_triton_extension_supported re-export, revert test.sh - require_triton(): add allow_module_level=True to pytest.skip() so module-level calls on old JAX produce a proper skip instead of a collection failure - Remove is_triton_extension_supported from triton_extensions/utils.py __all__: importing triton_extensions on JAX < 0.8.0 raises immediately, so re-exporting the check from there defeats its purpose; callers should import directly from transformer_engine.jax.version_utils - Revert qa/L0_jax_lint/test.sh TE_PATH to /opt/transformerengine (local dev path was accidentally committed; pass TE_PATH= at invocation time instead) Signed-off-by: tdophung --- qa/L0_jax_lint/test.sh | 2 +- tests/jax/utils.py | 3 ++- transformer_engine/jax/triton_extensions/utils.py | 2 +- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/qa/L0_jax_lint/test.sh b/qa/L0_jax_lint/test.sh index 9014ea90e9..3f804d3ef9 100755 --- a/qa/L0_jax_lint/test.sh +++ b/qa/L0_jax_lint/test.sh @@ -4,7 +4,7 @@ set -e -: "${TE_PATH:=/Users/tdophung/Repos/TransformerEngine}" +: "${TE_PATH:=/opt/transformerengine}" pip3 install cpplint==1.6.0 pylint==3.3.1 if [ -z "${PYTHON_ONLY}" ] diff --git a/tests/jax/utils.py b/tests/jax/utils.py index d412796cba..f592024a77 100644 --- a/tests/jax/utils.py +++ b/tests/jax/utils.py @@ -50,7 +50,8 @@ def require_triton(): pytest.skip( f"JAX >= {TRITON_EXTENSION_MIN_JAX_VERSION} required for Triton kernel support. " "Triton kernel dispatch segfaults with older jaxlib. " - "Upgrade with: pip install --upgrade jax jaxlib" + "Upgrade with: pip install --upgrade jax jaxlib", + allow_module_level=True, ) diff --git a/transformer_engine/jax/triton_extensions/utils.py b/transformer_engine/jax/triton_extensions/utils.py index 4453ab8098..6f7d40c587 100644 --- a/transformer_engine/jax/triton_extensions/utils.py +++ b/transformer_engine/jax/triton_extensions/utils.py @@ -179,7 +179,7 @@ def _check_triton_compatibility(): ) -__all__ = ["triton_call_lowering", "get_triton_info", "is_triton_extension_supported"] +__all__ = ["triton_call_lowering", "get_triton_info"] # Triton kernel cache (module-level, shared across all kernels) _TRITON_KERNEL_CACHE = {} From 93f342e66a02b17601d14a4f14f9af4213e96d75 Mon Sep 17 00:00:00 2001 From: tdophung Date: Mon, 9 Mar 2026 17:06:06 -0700 Subject: [PATCH 7/7] Address review: move version guard before gpu_triton import, fix __all__ and hardcoded version - Move is_triton_extension_supported() guard before the gpu_triton import block with a comment clarifying the segfault is at dispatch time, not import time - Remove _jax_version_meet_requirement from version_utils __all__ (private helper, not a public API; callers import it explicitly as needed) - Use TRITON_EXTENSION_MIN_JAX_VERSION constant in conftest marker description instead of hardcoded '0.8.0' Signed-off-by: tdophung --- tests/jax/conftest.py | 3 ++- .../jax/triton_extensions/utils.py | 26 +++++++++++-------- transformer_engine/jax/version_utils.py | 1 - 3 files changed, 17 insertions(+), 13 deletions(-) diff --git a/tests/jax/conftest.py b/tests/jax/conftest.py index 22e1622634..db30f0ed39 100644 --- a/tests/jax/conftest.py +++ b/tests/jax/conftest.py @@ -89,7 +89,8 @@ def pytest_sessionfinish(self, session, exitstatus): def pytest_configure(config): config.addinivalue_line( "markers", - "triton: mark test (or test class) as requiring JAX Triton kernel support (JAX >= 0.8.0)." + "triton: mark test (or test class) as requiring JAX Triton kernel support" + f" (JAX >= {TRITON_EXTENSION_MIN_JAX_VERSION})." " Apply per test/class with @pytest.mark.triton so non-Triton tests in the same file run on" " old JAX.", ) diff --git a/transformer_engine/jax/triton_extensions/utils.py b/transformer_engine/jax/triton_extensions/utils.py index 6f7d40c587..28e3f08e18 100644 --- a/transformer_engine/jax/triton_extensions/utils.py +++ b/transformer_engine/jax/triton_extensions/utils.py @@ -155,6 +155,21 @@ def _check_triton_compatibility(): # Perform compatibility check and get triton info _TRITON_VERSION, _IS_PYTORCH_TRITON = _check_triton_compatibility() +# Enforce minimum JAX version before importing gpu_triton. The segfault on old +# jaxlib occurs at Triton kernel dispatch time, not at import time, so gpu_triton +# itself is safe to import on older jaxlib. The guard is placed here (before the +# import) as a belt-and-suspenders measure so that if the import behaviour ever +# changes, we still fail fast with a clear error rather than a cryptic crash. +if not is_triton_extension_supported(): + raise RuntimeError( + f"JAX >= {TRITON_EXTENSION_MIN_JAX_VERSION} required for " + "transformer_engine.jax.triton_extensions. " + "Triton kernel dispatch segfaults with older jaxlib. " + f"Current jax version: {jax.__version__}. " + "Please upgrade: pip install --upgrade jax jaxlib. " + "If you don't need Triton, use transformer_engine.jax.cpp_extensions instead." + ) + try: from jax._src.lib import gpu_triton from triton.compiler import compiler as tc @@ -167,17 +182,6 @@ def _check_triton_compatibility(): "If you don't need Triton, use transformer_engine.jax.cpp_extensions instead." ) from e -# Enforce minimum JAX version for Triton kernel dispatch (segfaults on jaxlib < 0.8.0). -if not is_triton_extension_supported(): - raise RuntimeError( - f"JAX >= {TRITON_EXTENSION_MIN_JAX_VERSION} required for " - "transformer_engine.jax.triton_extensions. " - "Triton kernel dispatch segfaults with older jaxlib. " - f"Current jax version: {jax.__version__}. " - "Please upgrade: pip install --upgrade jax jaxlib. " - "If you don't need Triton, use transformer_engine.jax.cpp_extensions instead." - ) - __all__ = ["triton_call_lowering", "get_triton_info"] diff --git a/transformer_engine/jax/version_utils.py b/transformer_engine/jax/version_utils.py index 60997b44d2..2431b20887 100644 --- a/transformer_engine/jax/version_utils.py +++ b/transformer_engine/jax/version_utils.py @@ -37,7 +37,6 @@ def is_triton_extension_supported() -> bool: __all__ = [ - "_jax_version_meet_requirement", "is_triton_extension_supported", "TRITON_EXTENSION_MIN_JAX_VERSION", ]