diff --git a/qa/L0_jax_lint/test.sh b/qa/L0_jax_lint/test.sh old mode 100644 new mode 100755 diff --git a/tests/jax/conftest.py b/tests/jax/conftest.py index 6b7520d147..db30f0ed39 100644 --- a/tests/jax/conftest.py +++ b/tests/jax/conftest.py @@ -11,6 +11,10 @@ import transformer_engine.jax from transformer_engine_jax import get_device_compute_capability +from transformer_engine.jax.version_utils import ( + TRITON_EXTENSION_MIN_JAX_VERSION, + is_triton_extension_supported, +) @pytest.fixture(autouse=True, scope="function") @@ -83,5 +87,28 @@ def pytest_sessionfinish(self, session, exitstatus): def pytest_configure(config): + config.addinivalue_line( + "markers", + "triton: mark test (or test class) as requiring JAX Triton kernel support" + f" (JAX >= {TRITON_EXTENSION_MIN_JAX_VERSION})." + " Apply per test/class with @pytest.mark.triton so non-Triton tests in the same file run on" + " old JAX.", + ) if os.getenv("NVTE_JAX_TEST_TIMING", "0") == "1": config.pluginmanager.register(TestTimingPlugin(), "test_timing") + + +def pytest_collection_modifyitems(config, items): + """Skip tests marked 'triton' when JAX is too old for Triton kernel dispatch.""" + if is_triton_extension_supported(): + return + skip_triton = pytest.mark.skip( + reason=( + f"JAX >= {TRITON_EXTENSION_MIN_JAX_VERSION} required for Triton kernel support. " + "Triton kernel dispatch segfaults with older jaxlib. " + "Upgrade with: pip install --upgrade jax jaxlib" + ) + ) + for item in items: + if item.get_closest_marker("triton"): + item.add_marker(skip_triton) diff --git a/tests/jax/test_distributed_permutation.py b/tests/jax/test_distributed_permutation.py index 04ed236e81..8a01a79ce2 100644 --- a/tests/jax/test_distributed_permutation.py +++ b/tests/jax/test_distributed_permutation.py @@ -34,11 +34,23 @@ from distributed_test_base import generate_configs from utils import assert_allclose, pytest_parametrize_wrapper -# High-level API with VJP support -from transformer_engine.jax.permutation import ( - token_dispatch, - token_combine, -) + +@pytest.fixture(autouse=True, scope="function") +def _inject_permutation(request): + """Lazy-load permutation API only for tests marked 'triton'. Other tests run without importing.""" + if not request.node.get_closest_marker("triton"): + yield + return + import sys + from transformer_engine.jax.permutation import token_dispatch, token_combine + + mod = sys.modules[__name__] + mod.token_dispatch = token_dispatch + mod.token_combine = token_combine + yield + + +# High-level API with VJP support (injected by _inject_permutation) # Reference implementations from test_permutation.py from test_permutation import ( @@ -80,6 +92,7 @@ } +@pytest.mark.triton class TestDistributedPermutation: """Test distributed/sharded execution of MoE permutation primitives. diff --git a/tests/jax/test_distributed_router.py b/tests/jax/test_distributed_router.py index 1b3fe14e75..cd30402585 100644 --- a/tests/jax/test_distributed_router.py +++ b/tests/jax/test_distributed_router.py @@ -34,10 +34,24 @@ from distributed_test_base import generate_configs from utils import assert_allclose, pytest_parametrize_wrapper -from transformer_engine.jax.router import ( - fused_topk_with_score_function, - fused_moe_aux_loss, -) + +@pytest.fixture(autouse=True, scope="function") +def _inject_router(request): + """Lazy-load router API only for tests marked 'triton'. Other tests run without importing.""" + if not request.node.get_closest_marker("triton"): + yield + return + import sys + from transformer_engine.jax.router import ( + fused_topk_with_score_function, + fused_moe_aux_loss, + ) + + mod = sys.modules[__name__] + mod.fused_topk_with_score_function = fused_topk_with_score_function + mod.fused_moe_aux_loss = fused_moe_aux_loss + yield + jax.config.update("jax_use_shardy_partitioner", True) @@ -68,6 +82,7 @@ } +@pytest.mark.triton class TestDistributedFusedTopk: """Test distributed execution of fused_topk_with_score_function. @@ -200,6 +215,7 @@ def test_distributed_topk( ) +@pytest.mark.triton class TestDistributedScoreForAuxLoss: """Test distributed execution of fused_topk_with_score_function with compute_aux_scores=True. @@ -333,6 +349,7 @@ def test_distributed_score_for_aux_loss( ) +@pytest.mark.triton class TestDistributedMoEAuxLoss: """Test distributed execution of fused_moe_aux_loss. diff --git a/tests/jax/test_fused_router.py b/tests/jax/test_fused_router.py index 77e89457c8..1428c1f863 100644 --- a/tests/jax/test_fused_router.py +++ b/tests/jax/test_fused_router.py @@ -4,6 +4,7 @@ """Tests for fused MoE router CUDA kernels (JAX wrappers).""" +import sys from functools import partial from typing import Optional @@ -13,10 +14,23 @@ from utils import pytest_parametrize_wrapper -from transformer_engine.jax.router import ( - fused_topk_with_score_function, - fused_moe_aux_loss, -) + +@pytest.fixture(autouse=True, scope="function") +def _inject_router(request): + """Lazy-load router API only for tests marked 'triton'. Other tests run without importing.""" + if not request.node.get_closest_marker("triton"): + yield + return + from transformer_engine.jax.router import ( + fused_topk_with_score_function, + fused_moe_aux_loss, + ) + + mod = sys.modules[__name__] + mod.fused_topk_with_score_function = fused_topk_with_score_function + mod.fused_moe_aux_loss = fused_moe_aux_loss + yield + # ============================================================================= # Test case definitions (L0 = fast smoke, L2 = comprehensive) @@ -371,6 +385,7 @@ def loss_fused(logits_): @pytest_parametrize_wrapper("group_topk", GROUP_TOPK_OPTIONS) @pytest_parametrize_wrapper("scaling_factor", SCALING_FACTOR_OPTIONS) @pytest_parametrize_wrapper("enable_bias", ENABLE_BIAS_OPTIONS) +@pytest.mark.triton def test_topk_sigmoid( dtype, num_tokens, num_experts, topk, group_topk, scaling_factor, enable_bias ): @@ -397,6 +412,7 @@ def test_topk_sigmoid( @pytest_parametrize_wrapper("use_pre_softmax", USE_PRE_SOFTMAX_OPTIONS) @pytest_parametrize_wrapper("group_topk", GROUP_TOPK_OPTIONS) @pytest_parametrize_wrapper("scaling_factor", SCALING_FACTOR_OPTIONS) +@pytest.mark.triton def test_topk_softmax( dtype, num_tokens, num_experts, topk, use_pre_softmax, group_topk, scaling_factor ): @@ -426,6 +442,7 @@ def test_topk_softmax( SCORE_AUX_LOSS_CASES, ) @pytest_parametrize_wrapper("score_function", SCORE_FUNCTIONS) +@pytest.mark.triton def test_fused_scores_for_aux_loss(dtype, num_tokens, num_experts, topk, score_function): logits = make_logits(num_tokens, num_experts, score_function, dtype) @@ -486,6 +503,7 @@ def loss_fused(logits_): "num_tokens,num_experts,topk", AUX_LOSS_CASES, ) +@pytest.mark.triton def test_fused_moe_aux_loss(dtype, num_tokens, num_experts, topk): key = jax.random.PRNGKey(SEED) diff --git a/tests/jax/test_permutation.py b/tests/jax/test_permutation.py index 138a817240..74158a93e9 100644 --- a/tests/jax/test_permutation.py +++ b/tests/jax/test_permutation.py @@ -5,20 +5,34 @@ """Tests for permutation Triton kernels and high-level APIs""" import functools +import sys import jax import jax.numpy as jnp import pytest -# High-level API with VJP support -from transformer_engine.jax.permutation import ( - token_dispatch, - token_combine, - sort_chunks_by_index, -) from utils import assert_allclose, pytest_parametrize_wrapper +@pytest.fixture(autouse=True, scope="function") +def _inject_permutation(request): + """Lazy-load permutation API only for tests marked 'triton'. Other tests run without importing.""" + if not request.node.get_closest_marker("triton"): + yield + return + from transformer_engine.jax.permutation import ( + token_dispatch, + token_combine, + sort_chunks_by_index, + ) + + mod = sys.modules[__name__] + mod.token_dispatch = token_dispatch + mod.token_combine = token_combine + mod.sort_chunks_by_index = sort_chunks_by_index + yield + + ALL_DISPATCH_COMBINE_CASES = [ (128, 5, 128, 3), (1024, 8, 128, 8), @@ -449,6 +463,7 @@ def reference_sort_chunks_by_map( return output, permuted_probs +@pytest.mark.triton class TestHighLevelPermutationAPI: """Test high-level permutation APIs (token_dispatch, token_combine, etc.) diff --git a/tests/jax/test_triton_custom_calls.py b/tests/jax/test_triton_custom_calls.py index 6d969de0d3..139c6b91a6 100644 --- a/tests/jax/test_triton_custom_calls.py +++ b/tests/jax/test_triton_custom_calls.py @@ -7,7 +7,9 @@ import jax.numpy as jnp import pytest -from utils import assert_allclose, pytest_parametrize_wrapper +from utils import assert_allclose, pytest_parametrize_wrapper, require_triton + +require_triton() import triton import triton.language as tl @@ -23,6 +25,7 @@ def init(): yield +@pytest.mark.triton class TestTritonBinding: """Test Triton binding primitive.""" diff --git a/tests/jax/utils.py b/tests/jax/utils.py index c22b0a6063..f592024a77 100644 --- a/tests/jax/utils.py +++ b/tests/jax/utils.py @@ -26,6 +26,10 @@ make_swa_mask, ) from transformer_engine.jax.quantize.helper import DType as TEDType +from transformer_engine.jax.version_utils import ( + TRITON_EXTENSION_MIN_JAX_VERSION, + is_triton_extension_supported, +) PRNGKey = Any Shape = Tuple[int, ...] @@ -40,6 +44,17 @@ NVTE_DEBUG_NUMERICS = bool(int(os.getenv("NVTE_DEBUG_NUMERICS", 0))) +def require_triton(): + """Skip the current test module if JAX is too old for Triton kernel support.""" + if not is_triton_extension_supported(): + pytest.skip( + f"JAX >= {TRITON_EXTENSION_MIN_JAX_VERSION} required for Triton kernel support. " + "Triton kernel dispatch segfaults with older jaxlib. " + "Upgrade with: pip install --upgrade jax jaxlib", + allow_module_level=True, + ) + + def is_devices_enough(required): """ Check if the available GPUs is enough diff --git a/transformer_engine/jax/quantize/helper.py b/transformer_engine/jax/quantize/helper.py index 7d81d71bc8..2548ec9e46 100644 --- a/transformer_engine/jax/quantize/helper.py +++ b/transformer_engine/jax/quantize/helper.py @@ -14,11 +14,9 @@ from enum import Enum import hashlib from typing import Optional, Tuple, Dict, Union, Sequence, Type, List -from functools import reduce, lru_cache +from functools import reduce import operator -from importlib.metadata import version as get_pkg_version import warnings -from packaging.version import Version as PkgVersion import jax import jax.numpy as jnp @@ -40,6 +38,7 @@ get_all_mesh_axes, with_sharding_constraint, ) +from transformer_engine.jax.version_utils import _jax_version_meet_requirement from .metadata import QuantizeMeta from .scaling_modes import ScalingMode @@ -68,16 +67,6 @@ NVTE_FP8_COLLECTION_NAME = "fp8_metas" -@lru_cache(maxsize=None) -def _jax_version_meet_requirement(version: str): - """ - Helper function checking if required JAX version is available - """ - jax_version = PkgVersion(get_pkg_version("jax")) - jax_version_required = PkgVersion(version) - return jax_version >= jax_version_required - - def _check_delayed_scaling_fp8_support(gpu_arch) -> Tuple[bool, str]: """Check if delayed scaling FP8 is supported on the given GPU architecture. diff --git a/transformer_engine/jax/triton_extensions/__init__.py b/transformer_engine/jax/triton_extensions/__init__.py index d9708fde9f..150a5fbf12 100644 --- a/transformer_engine/jax/triton_extensions/__init__.py +++ b/transformer_engine/jax/triton_extensions/__init__.py @@ -54,6 +54,9 @@ def lowering(ctx, x, **kwargs): from transformer_engine.jax.triton_extensions import get_triton_info info = get_triton_info() print(f"Using Triton {info['version']} from {info['source']}") + + # Check if JAX version supports Triton (without importing triton_extensions) + from transformer_engine.jax.version_utils import is_triton_extension_supported """ from .utils import * diff --git a/transformer_engine/jax/triton_extensions/utils.py b/transformer_engine/jax/triton_extensions/utils.py index 2627a08929..28e3f08e18 100644 --- a/transformer_engine/jax/triton_extensions/utils.py +++ b/transformer_engine/jax/triton_extensions/utils.py @@ -42,6 +42,11 @@ import jax import jax.numpy as jnp +from ..version_utils import ( + TRITON_EXTENSION_MIN_JAX_VERSION, + is_triton_extension_supported, +) + # Placeholder package version on PyPI that should never be used _PYTORCH_TRITON_PLACEHOLDER_VERSION = "0.0.1" @@ -150,6 +155,21 @@ def _check_triton_compatibility(): # Perform compatibility check and get triton info _TRITON_VERSION, _IS_PYTORCH_TRITON = _check_triton_compatibility() +# Enforce minimum JAX version before importing gpu_triton. The segfault on old +# jaxlib occurs at Triton kernel dispatch time, not at import time, so gpu_triton +# itself is safe to import on older jaxlib. The guard is placed here (before the +# import) as a belt-and-suspenders measure so that if the import behaviour ever +# changes, we still fail fast with a clear error rather than a cryptic crash. +if not is_triton_extension_supported(): + raise RuntimeError( + f"JAX >= {TRITON_EXTENSION_MIN_JAX_VERSION} required for " + "transformer_engine.jax.triton_extensions. " + "Triton kernel dispatch segfaults with older jaxlib. " + f"Current jax version: {jax.__version__}. " + "Please upgrade: pip install --upgrade jax jaxlib. " + "If you don't need Triton, use transformer_engine.jax.cpp_extensions instead." + ) + try: from jax._src.lib import gpu_triton from triton.compiler import compiler as tc diff --git a/transformer_engine/jax/version_utils.py b/transformer_engine/jax/version_utils.py new file mode 100644 index 0000000000..2431b20887 --- /dev/null +++ b/transformer_engine/jax/version_utils.py @@ -0,0 +1,42 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +""" +JAX version helpers. + +Provides version checks for JAX that can be used across TE JAX (quantize, triton +extensions, etc.) without pulling in feature-specific code. +""" + +from functools import lru_cache +from importlib.metadata import version as get_pkg_version + +from packaging.version import Version as PkgVersion + + +@lru_cache(maxsize=None) +def _jax_version_meet_requirement(version: str): + """Return True if the installed JAX version is >= the required version.""" + jax_version = PkgVersion(get_pkg_version("jax")) + jax_version_required = PkgVersion(version) + return jax_version >= jax_version_required + + +# Minimum JAX version required for Triton kernel dispatch (jaxlib < 0.8.0 segfaults). +TRITON_EXTENSION_MIN_JAX_VERSION = "0.8.0" + + +def is_triton_extension_supported() -> bool: + """Return True if the current JAX version supports Triton kernel dispatch. + + JAX/jaxlib >= 0.8.0 is required. Older versions segfault when dispatching + Triton kernels. Use this to skip tests or gate features without importing + triton_extensions (which would raise immediately on old jax). + """ + return _jax_version_meet_requirement(TRITON_EXTENSION_MIN_JAX_VERSION) + + +__all__ = [ + "is_triton_extension_supported", + "TRITON_EXTENSION_MIN_JAX_VERSION", +]