Skip to content
Empty file modified qa/L0_jax_lint/test.sh
100644 → 100755
Empty file.
27 changes: 27 additions & 0 deletions tests/jax/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@

import transformer_engine.jax
from transformer_engine_jax import get_device_compute_capability
from transformer_engine.jax.version_utils import (
TRITON_EXTENSION_MIN_JAX_VERSION,
is_triton_extension_supported,
)


@pytest.fixture(autouse=True, scope="function")
Expand Down Expand Up @@ -83,5 +87,28 @@ def pytest_sessionfinish(self, session, exitstatus):


def pytest_configure(config):
    """Register the custom 'triton' marker and, optionally, the timing plugin.

    The timing plugin is attached only when NVTE_JAX_TEST_TIMING=1 is set in
    the environment.
    """
    marker_help = (
        "triton: mark test (or test class) as requiring JAX Triton kernel support"
        f" (JAX >= {TRITON_EXTENSION_MIN_JAX_VERSION})."
        " Apply per test/class with @pytest.mark.triton so non-Triton tests in the same file run on"
        " old JAX."
    )
    config.addinivalue_line("markers", marker_help)
    timing_enabled = os.getenv("NVTE_JAX_TEST_TIMING", "0") == "1"
    if timing_enabled:
        config.pluginmanager.register(TestTimingPlugin(), "test_timing")


def pytest_collection_modifyitems(config, items):
    """Skip tests marked 'triton' when JAX is too old for Triton kernel dispatch."""
    if is_triton_extension_supported():
        return
    # Build the skip marker once and attach it to every 'triton'-marked item.
    reason_text = (
        f"JAX >= {TRITON_EXTENSION_MIN_JAX_VERSION} required for Triton kernel support. "
        "Triton kernel dispatch segfaults with older jaxlib. "
        "Upgrade with: pip install --upgrade jax jaxlib"
    )
    triton_skip = pytest.mark.skip(reason=reason_text)
    for test_item in items:
        if test_item.get_closest_marker("triton") is not None:
            test_item.add_marker(triton_skip)
23 changes: 18 additions & 5 deletions tests/jax/test_distributed_permutation.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,23 @@
from distributed_test_base import generate_configs
from utils import assert_allclose, pytest_parametrize_wrapper

# High-level API with VJP support
from transformer_engine.jax.permutation import (
token_dispatch,
token_combine,
)

@pytest.fixture(autouse=True, scope="function")
def _inject_permutation(request):
    """Lazy-load the permutation API only for tests marked 'triton'.

    Tests without the marker run without ever importing the Triton-backed
    module, so they stay usable on older JAX installs.
    """
    if request.node.get_closest_marker("triton") is None:
        yield
        return
    import sys
    from transformer_engine.jax.permutation import token_dispatch, token_combine

    # Publish the lazily-imported API as module-level names for the test body.
    target_module = sys.modules[__name__]
    setattr(target_module, "token_dispatch", token_dispatch)
    setattr(target_module, "token_combine", token_combine)
    yield


# High-level API with VJP support (injected by _inject_permutation)

# Reference implementations from test_permutation.py
from test_permutation import (
Expand Down Expand Up @@ -80,6 +92,7 @@
}


@pytest.mark.triton
class TestDistributedPermutation:
"""Test distributed/sharded execution of MoE permutation primitives.

Expand Down
25 changes: 21 additions & 4 deletions tests/jax/test_distributed_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,24 @@
from distributed_test_base import generate_configs
from utils import assert_allclose, pytest_parametrize_wrapper

from transformer_engine.jax.router import (
fused_topk_with_score_function,
fused_moe_aux_loss,
)

@pytest.fixture(autouse=True, scope="function")
def _inject_router(request):
    """Lazy-load the router API only for tests marked 'triton'.

    Tests without the marker run without ever importing the Triton-backed
    module, so they stay usable on older JAX installs.
    """
    if request.node.get_closest_marker("triton") is None:
        yield
        return
    import sys
    from transformer_engine.jax.router import (
        fused_topk_with_score_function,
        fused_moe_aux_loss,
    )

    # Publish the lazily-imported API as module-level names for the test body.
    target_module = sys.modules[__name__]
    setattr(target_module, "fused_topk_with_score_function", fused_topk_with_score_function)
    setattr(target_module, "fused_moe_aux_loss", fused_moe_aux_loss)
    yield


jax.config.update("jax_use_shardy_partitioner", True)

Expand Down Expand Up @@ -68,6 +82,7 @@
}


@pytest.mark.triton
class TestDistributedFusedTopk:
"""Test distributed execution of fused_topk_with_score_function.

Expand Down Expand Up @@ -200,6 +215,7 @@ def test_distributed_topk(
)


@pytest.mark.triton
class TestDistributedScoreForAuxLoss:
"""Test distributed execution of fused_topk_with_score_function with compute_aux_scores=True.

Expand Down Expand Up @@ -333,6 +349,7 @@ def test_distributed_score_for_aux_loss(
)


@pytest.mark.triton
class TestDistributedMoEAuxLoss:
"""Test distributed execution of fused_moe_aux_loss.

Expand Down
26 changes: 22 additions & 4 deletions tests/jax/test_fused_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

"""Tests for fused MoE router CUDA kernels (JAX wrappers)."""

import sys
from functools import partial
from typing import Optional

Expand All @@ -13,10 +14,23 @@

from utils import pytest_parametrize_wrapper

from transformer_engine.jax.router import (
fused_topk_with_score_function,
fused_moe_aux_loss,
)

@pytest.fixture(autouse=True, scope="function")
def _inject_router(request):
    """Lazy-load the router API only for tests marked 'triton'.

    Unmarked tests yield immediately and never trigger the Triton-backed
    import, so they remain runnable on older JAX installs.
    """
    if request.node.get_closest_marker("triton") is None:
        yield
        return
    from transformer_engine.jax.router import (
        fused_topk_with_score_function,
        fused_moe_aux_loss,
    )

    # Inject the lazily-imported callables into this module's namespace;
    # updating globals() is equivalent to assigning on sys.modules[__name__].
    globals()["fused_topk_with_score_function"] = fused_topk_with_score_function
    globals()["fused_moe_aux_loss"] = fused_moe_aux_loss
    yield


# =============================================================================
# Test case definitions (L0 = fast smoke, L2 = comprehensive)
Expand Down Expand Up @@ -371,6 +385,7 @@ def loss_fused(logits_):
@pytest_parametrize_wrapper("group_topk", GROUP_TOPK_OPTIONS)
@pytest_parametrize_wrapper("scaling_factor", SCALING_FACTOR_OPTIONS)
@pytest_parametrize_wrapper("enable_bias", ENABLE_BIAS_OPTIONS)
@pytest.mark.triton
def test_topk_sigmoid(
dtype, num_tokens, num_experts, topk, group_topk, scaling_factor, enable_bias
):
Expand All @@ -397,6 +412,7 @@ def test_topk_sigmoid(
@pytest_parametrize_wrapper("use_pre_softmax", USE_PRE_SOFTMAX_OPTIONS)
@pytest_parametrize_wrapper("group_topk", GROUP_TOPK_OPTIONS)
@pytest_parametrize_wrapper("scaling_factor", SCALING_FACTOR_OPTIONS)
@pytest.mark.triton
def test_topk_softmax(
dtype, num_tokens, num_experts, topk, use_pre_softmax, group_topk, scaling_factor
):
Expand Down Expand Up @@ -426,6 +442,7 @@ def test_topk_softmax(
SCORE_AUX_LOSS_CASES,
)
@pytest_parametrize_wrapper("score_function", SCORE_FUNCTIONS)
@pytest.mark.triton
def test_fused_scores_for_aux_loss(dtype, num_tokens, num_experts, topk, score_function):
logits = make_logits(num_tokens, num_experts, score_function, dtype)

Expand Down Expand Up @@ -486,6 +503,7 @@ def loss_fused(logits_):
"num_tokens,num_experts,topk",
AUX_LOSS_CASES,
)
@pytest.mark.triton
def test_fused_moe_aux_loss(dtype, num_tokens, num_experts, topk):
key = jax.random.PRNGKey(SEED)

Expand Down
27 changes: 21 additions & 6 deletions tests/jax/test_permutation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,34 @@
"""Tests for permutation Triton kernels and high-level APIs"""

import functools
import sys

import jax
import jax.numpy as jnp
import pytest

# High-level API with VJP support
from transformer_engine.jax.permutation import (
token_dispatch,
token_combine,
sort_chunks_by_index,
)
from utils import assert_allclose, pytest_parametrize_wrapper


@pytest.fixture(autouse=True, scope="function")
def _inject_permutation(request):
    """Lazy-load the permutation API only for tests marked 'triton'.

    Unmarked tests yield immediately and never trigger the Triton-backed
    import, so they remain runnable on older JAX installs.
    """
    if request.node.get_closest_marker("triton") is None:
        yield
        return
    from transformer_engine.jax.permutation import (
        token_dispatch,
        token_combine,
        sort_chunks_by_index,
    )

    # Publish the lazily-imported API as module-level names for the test body.
    this_module = sys.modules[__name__]
    for api_name, api_fn in (
        ("token_dispatch", token_dispatch),
        ("token_combine", token_combine),
        ("sort_chunks_by_index", sort_chunks_by_index),
    ):
        setattr(this_module, api_name, api_fn)
    yield


ALL_DISPATCH_COMBINE_CASES = [
(128, 5, 128, 3),
(1024, 8, 128, 8),
Expand Down Expand Up @@ -449,6 +463,7 @@ def reference_sort_chunks_by_map(
return output, permuted_probs


@pytest.mark.triton
class TestHighLevelPermutationAPI:
"""Test high-level permutation APIs (token_dispatch, token_combine, etc.)

Expand Down
5 changes: 4 additions & 1 deletion tests/jax/test_triton_custom_calls.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
import jax.numpy as jnp
import pytest

from utils import assert_allclose, pytest_parametrize_wrapper
from utils import assert_allclose, pytest_parametrize_wrapper, require_triton

require_triton()

import triton
import triton.language as tl
Expand All @@ -23,6 +25,7 @@ def init():
yield


@pytest.mark.triton
class TestTritonBinding:
"""Test Triton binding primitive."""

Expand Down
15 changes: 15 additions & 0 deletions tests/jax/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@
make_swa_mask,
)
from transformer_engine.jax.quantize.helper import DType as TEDType
from transformer_engine.jax.version_utils import (
TRITON_EXTENSION_MIN_JAX_VERSION,
is_triton_extension_supported,
)

PRNGKey = Any
Shape = Tuple[int, ...]
Expand All @@ -40,6 +44,17 @@
NVTE_DEBUG_NUMERICS = bool(int(os.getenv("NVTE_DEBUG_NUMERICS", 0)))


def require_triton():
    """Skip the current test module if JAX is too old for Triton kernel support."""
    if is_triton_extension_supported():
        return
    # Module-level skip: callers invoke this at import time of the test file.
    pytest.skip(
        f"JAX >= {TRITON_EXTENSION_MIN_JAX_VERSION} required for Triton kernel support. "
        "Triton kernel dispatch segfaults with older jaxlib. "
        "Upgrade with: pip install --upgrade jax jaxlib",
        allow_module_level=True,
    )


def is_devices_enough(required):
"""
Check if the available GPUs is enough
Expand Down
15 changes: 2 additions & 13 deletions transformer_engine/jax/quantize/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,9 @@
from enum import Enum
import hashlib
from typing import Optional, Tuple, Dict, Union, Sequence, Type, List
from functools import reduce, lru_cache
from functools import reduce
import operator
from importlib.metadata import version as get_pkg_version
import warnings
from packaging.version import Version as PkgVersion

import jax
import jax.numpy as jnp
Expand All @@ -40,6 +38,7 @@
get_all_mesh_axes,
with_sharding_constraint,
)
from transformer_engine.jax.version_utils import _jax_version_meet_requirement

from .metadata import QuantizeMeta
from .scaling_modes import ScalingMode
Expand Down Expand Up @@ -68,16 +67,6 @@
NVTE_FP8_COLLECTION_NAME = "fp8_metas"


@lru_cache(maxsize=None)
def _jax_version_meet_requirement(version: str) -> bool:
    """Check whether the installed JAX version satisfies a minimum requirement.

    Args:
        version: Minimum required JAX version string, parsed with
            ``packaging.version.Version`` semantics.

    Returns:
        True if the installed ``jax`` distribution's version is >= ``version``.

    Note:
        The result is memoized per ``version`` string via ``lru_cache``, so
        the installed-package lookup happens at most once per requirement.
    """
    jax_version = PkgVersion(get_pkg_version("jax"))
    jax_version_required = PkgVersion(version)
    return jax_version >= jax_version_required


def _check_delayed_scaling_fp8_support(gpu_arch) -> Tuple[bool, str]:
"""Check if delayed scaling FP8 is supported on the given GPU architecture.

Expand Down
3 changes: 3 additions & 0 deletions transformer_engine/jax/triton_extensions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ def lowering(ctx, x, **kwargs):
from transformer_engine.jax.triton_extensions import get_triton_info
info = get_triton_info()
print(f"Using Triton {info['version']} from {info['source']}")

# Check if JAX version supports Triton (without importing triton_extensions)
from transformer_engine.jax.version_utils import is_triton_extension_supported
"""

from .utils import *
Expand Down
20 changes: 20 additions & 0 deletions transformer_engine/jax/triton_extensions/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,11 @@
import jax
import jax.numpy as jnp

from ..version_utils import (
TRITON_EXTENSION_MIN_JAX_VERSION,
is_triton_extension_supported,
)


# Placeholder package version on PyPI that should never be used
_PYTORCH_TRITON_PLACEHOLDER_VERSION = "0.0.1"
Expand Down Expand Up @@ -150,6 +155,21 @@ def _check_triton_compatibility():
# Perform compatibility check and get triton info
_TRITON_VERSION, _IS_PYTORCH_TRITON = _check_triton_compatibility()

# Enforce minimum JAX version before importing gpu_triton. The segfault on old
# jaxlib occurs at Triton kernel dispatch time, not at import time, so gpu_triton
# itself is safe to import on older jaxlib. The guard is placed here (before the
# import) as a belt-and-suspenders measure so that if the import behaviour ever
# changes, we still fail fast with a clear error rather than a cryptic crash.
if not is_triton_extension_supported():
    # Fail fast at import time with an actionable message instead of a
    # cryptic crash later at kernel-dispatch time.
    _unsupported_jax_message = (
        f"JAX >= {TRITON_EXTENSION_MIN_JAX_VERSION} required for "
        "transformer_engine.jax.triton_extensions. "
        "Triton kernel dispatch segfaults with older jaxlib. "
        f"Current jax version: {jax.__version__}. "
        "Please upgrade: pip install --upgrade jax jaxlib. "
        "If you don't need Triton, use transformer_engine.jax.cpp_extensions instead."
    )
    raise RuntimeError(_unsupported_jax_message)

try:
from jax._src.lib import gpu_triton
from triton.compiler import compiler as tc
Expand Down
Loading
Loading