11 changes: 11 additions & 0 deletions docs/examples/jax/attention.rst
@@ -0,0 +1,11 @@
..
    Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.

    See LICENSE for license information.

JAX: Attention with TransformerEngine
=====================================

**TODO — Coming soon.**

`← Back to the JAX integration overview <../te_jax_integration.html>`_
Comment on lines +1 to +11
Collaborator

Unrelated to attention, but it looks like you are renaming the dir to examples/jax_examples, whereas I think the PyTorch side is examples/pytorch?
I think we could stick with examples/jax. Thoughts?

Collaborator Author

Good point, updated to examples/jax

11 changes: 11 additions & 0 deletions docs/examples/jax/collective_gemm.rst
@@ -0,0 +1,11 @@
..
    Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.

    See LICENSE for license information.

JAX: Collective GEMMs with TransformerEngine
=============================================

**TODO — Coming soon.**

`← Back to the JAX integration overview <../te_jax_integration.html>`_
21 changes: 21 additions & 0 deletions docs/examples/jax/dense.out
@@ -0,0 +1,21 @@
# Numbers below are illustrative (captured on a GB200). Regenerate with:
# python3 docs/examples/jax/dense.py > dense.out

# SINGLE_GPU_OUTPUT_START
Variable collections: ['params']
{'params': {'Dense_0': {'kernel': ((8192, 32768), dtype('float32'))}}}

bf16 baseline:
Mean time: 18.056 ms

TE MXFP8BlockScaling:
Mean time: 11.260 ms
# SINGLE_GPU_OUTPUT_END

# MULTI_GPU_OUTPUT_START
bf16 DP=2/TP=2:
Mean time: 5.516 ms

TE MXFP8BlockScaling DP=2/TP=2:
Mean time: 3.712 ms
# MULTI_GPU_OUTPUT_END
180 changes: 180 additions & 0 deletions docs/examples/jax/dense.py
@@ -0,0 +1,180 @@
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""JAX: Dense GEMMs with TransformerEngine.

Companion source for ``dense.rst``. Code blocks between ``# DENSE_*_START`` /
``# DENSE_*_END`` markers are pulled into the RST via ``literalinclude``.

Run as a script to exercise the example end-to-end:

    python docs/examples/jax/dense.py

Pytest tests live in ``test_dense.py``; the multi-GPU section auto-skips when
fewer than 4 GPUs are visible.
"""

# DENSE_IMPORTS_START
import jax
import jax.numpy as jnp
from flax import linen as nn

import quickstart_jax_utils as utils

# DENSE_IMPORTS_END


# DENSE_BASELINE_MODEL_START
class FlaxDenseBlock(nn.Module):
    """One linear layer. ``dot_general_cls`` lets us swap the GEMM impl."""

    features: int
    dtype: jnp.dtype = jnp.bfloat16
    dot_general_cls: callable = lambda: None

    @nn.compact
    def __call__(self, x):
        return nn.Dense(
            features=self.features,
            use_bias=False,
            dtype=self.dtype,
            dot_general=self.dot_general_cls(),
        )(x)


# DENSE_BASELINE_MODEL_END


# DENSE_INPUTS_SETUP_START
batch, seq, hidden, out_features = 8, 2048, 8192, 32768
dtype = jnp.bfloat16

key = jax.random.PRNGKey(0)
k_init, k_x, k_dy = jax.random.split(key, 3)
x = jax.random.normal(k_x, (batch, seq, hidden)).astype(dtype)
dy = jax.random.normal(k_dy, (batch, seq, out_features)).astype(dtype)

baseline = FlaxDenseBlock(features=out_features)
baseline_vars = baseline.init(k_init, x)
# DENSE_INPUTS_SETUP_END


# DENSE_TE_SETUP_START
from transformer_engine.jax import flax as te_flax
from transformer_engine.common.recipe import MXFP8BlockScaling

recipe = MXFP8BlockScaling()
te_dot_general_cls = te_flax.make_dot_general_cls(recipe)

te_model = FlaxDenseBlock(features=out_features, dot_general_cls=te_dot_general_cls)
te_vars = te_model.init(k_init, x)

print("Variable collections:", list(te_vars.keys()))
print(jax.tree_util.tree_map(lambda a: (a.shape, a.dtype), te_vars))
# DENSE_TE_SETUP_END


# DENSE_SINGLE_GPU_BENCH_START
def run_single_gpu_bench():
    print("bf16 baseline:")
    utils.speedometer(
        model_apply_fn=baseline.apply,
        variables=baseline_vars,
        input=x,
        output_grad=dy,
    )

    print(f"\nTE {type(recipe).__name__}:")
    utils.speedometer(
        model_apply_fn=te_model.apply,
        variables=te_vars,
        input=x,
        output_grad=dy,
    )


# DENSE_SINGLE_GPU_BENCH_END


# DENSE_MULTI_GPU_MESH_SETUP_START
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax.experimental import mesh_utils
from transformer_engine.jax.sharding import MeshResource, global_shard_guard


def build_dp_tp_mesh():
    # 2x2 mesh: DP on one axis, TP on the other.
    devices = mesh_utils.create_device_mesh((2, 2))
    mesh = Mesh(devices, axis_names=("dp", "tp"))

    # Tell TE which mesh axis is which. This is a *global* setting, established
    # outside JIT, so TE's GEMM primitives can plan comms accordingly.
    mesh_resource = MeshResource(dp_resource="dp", tp_resource="tp")
    return mesh, mesh_resource


# DENSE_MULTI_GPU_MESH_SETUP_END


# DENSE_MULTI_GPU_SHARD_SETUP_START
def shard_variables(mesh, variables_dict):
    kernel_sharding = NamedSharding(mesh, P(None, "tp"))

    def _shard(variables):
        params = variables["params"]
        sharded = jax.device_put(params["Dense_0"]["kernel"], kernel_sharding)
        return {
            **variables,
            "params": {
                **params,
                "Dense_0": {**params["Dense_0"], "kernel": sharded},
            },
        }

    input_sharding = NamedSharding(mesh, P("dp", None, None))
    output_grad_sharding = NamedSharding(mesh, P("dp", None, "tp"))

    return {
        "x": jax.device_put(x, input_sharding),
        "dy": jax.device_put(dy, output_grad_sharding),
        **{name: _shard(vars_) for name, vars_ in variables_dict.items()},
    }


# DENSE_MULTI_GPU_SHARD_SETUP_END


# DENSE_MULTI_GPU_BENCH_START
def run_multi_gpu_bench():
    mesh, mesh_resource = build_dp_tp_mesh()
    sharded = shard_variables(mesh, {"baseline": baseline_vars, "te": te_vars})

    with jax.set_mesh(mesh), global_shard_guard(mesh_resource):
        print("bf16 DP=2/TP=2:")
        utils.speedometer(
            model_apply_fn=baseline.apply,
            variables=sharded["baseline"],
            input=sharded["x"],
            output_grad=sharded["dy"],
        )

        print(f"\nTE {type(recipe).__name__} DP=2/TP=2:")
        utils.speedometer(
            model_apply_fn=te_model.apply,
            variables=sharded["te"],
            input=sharded["x"],
            output_grad=sharded["dy"],
        )


# DENSE_MULTI_GPU_BENCH_END


if __name__ == "__main__":
    run_single_gpu_bench()
    if len(jax.devices()) >= 4:
        print()
        run_multi_gpu_bench()
    else:
        print("\n[skipped multi-GPU section: <4 devices visible]")
168 changes: 168 additions & 0 deletions docs/examples/jax/dense.rst
@@ -0,0 +1,168 @@
..
    Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.

    See LICENSE for license information.

JAX: Dense GEMMs with TransformerEngine
=======================================

This document walks through replacing a plain ``flax.linen.Dense``'s GEMM with
TransformerEngine's quantized GEMM.

**Recipe.** We use ``MXFP8BlockScaling`` in this tutorial. ``MXFP8BlockScaling`` and
``NVFP4BlockScaling`` require a Blackwell-class GPU; on Hopper, swap in
``DelayedScaling`` or ``Float8CurrentScaling``.

`← Back to the JAX integration overview <../te_jax_integration.html>`_

1. Baseline: a plain Flax Dense block
-------------------------------------

We isolate the optimization to a single linear layer so it's clear what's
changing. ``dot_general_cls`` is exposed as a constructor argument so we can swap
in TE later without touching the model definition.

.. literalinclude:: dense.py
    :language: python
    :start-after: # DENSE_BASELINE_MODEL_START
    :end-before: # DENSE_BASELINE_MODEL_END

.. literalinclude:: dense.py
    :language: python
    :start-after: # DENSE_INPUTS_SETUP_START
    :end-before: # DENSE_INPUTS_SETUP_END


2. Quantized Dense via ``make_dot_general_cls``
-----------------------------------------------

TE exposes a helper, ``te_flax.make_dot_general_cls(recipe)``, that returns a Flax
module class. You instantiate it and pass the instance to
``nn.Dense(..., dot_general=...)``.

With this API, TE doesn't create the ``kernel`` params; it only wraps the GEMM.
All your initialization, sharding annotations, and optimizer state stay where
they were.

.. literalinclude:: dense.py
    :language: python
    :start-after: # DENSE_TE_SETUP_START
    :end-before: # DENSE_TE_SETUP_END

.. note::

    **What about DelayedScaling state?**

    Most recipes are stateless: scaling factors are computed from each tensor
    as it flows through the GEMM, so there is nothing to persist across steps.
    However, if you swap in ``DelayedScaling`` instead, ``init`` will produce a
    second variable collection, ``_overwrite_with_gradient``, holding
    ``kernel_amax_history``, ``kernel_scale``, ``x_amax_history``, ``x_scale``,
    etc. These are **not** model parameters. They are Flax variables that TE
    updates each step to compute per-tensor scales from a rolling amax window.

    If you use ``DelayedScaling``, you must thread the *entire* variable dict
    returned by ``init`` (``te_vars`` in this example, not just its ``params``
    entry) through your training loop so the history persists across steps.
    ``MXFP8BlockScaling``, ``NVFP4BlockScaling``, and ``Float8CurrentScaling``
    do not require this.
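
To illustrate why the amax history has to survive from one step to the next, here is a simplified, pure-Python sketch of delayed scaling. It is not TE's implementation: the class name ``RollingAmaxState``, the short window length, and the update rule (scale = FP8 E4M3 maximum divided by the largest amax in the window) are assumptions chosen for illustration.

```python
from collections import deque

# Largest finite magnitude representable in FP8 E4M3.
FP8_E4M3_MAX = 448.0


class RollingAmaxState:
    """Hypothetical stand-in for per-tensor delayed-scaling state."""

    def __init__(self, history_len=2):
        # Rolling window of the most recent per-step absolute maxima.
        self.amax_history = deque(maxlen=history_len)
        self.scale = 1.0

    def update(self, values):
        """Record this step's amax and recompute the scale from the window."""
        self.amax_history.append(max(abs(v) for v in values))
        window_amax = max(self.amax_history)
        if window_amax > 0.0:
            self.scale = FP8_E4M3_MAX / window_amax
        return self.scale


# The same state object is carried across three "training steps".
state = RollingAmaxState(history_len=2)
scales = [state.update(step) for step in ([1.0, -2.0], [0.5], [0.25])]
print(scales)
```

The second step still uses the amax of 2.0 seen in the first step, because that value is inside the window. Rebuilding the state on every step would discard that history, which is the effect of passing only ``params`` through the training loop.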


3. Single-GPU performance
-------------------------

``speedometer`` runs a JIT-compiled forward+backward loop with warmup, on the
same input for both models.
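
As a rough picture of the warmup-then-time pattern, the sketch below times a generic step function in plain Python. It is not the actual ``utils.speedometer``: the function name ``mean_time_ms``, its parameters, and the default iteration counts are hypothetical.

```python
import time


def mean_time_ms(step_fn, warmup_iters=3, timing_iters=10, sync=lambda out: out):
    """Return the mean wall-clock time of ``step_fn`` in milliseconds.

    ``sync`` is called on every result so that asynchronous work is finished
    before the clock is read.
    """
    # Warmup runs absorb one-time costs such as compilation.
    for _ in range(warmup_iters):
        sync(step_fn())

    start = time.perf_counter()
    for _ in range(timing_iters):
        sync(step_fn())
    elapsed_s = time.perf_counter() - start
    return elapsed_s * 1000.0 / timing_iters


# Toy usage: count how often the step function runs.
calls = []
mean_ms = mean_time_ms(lambda: calls.append(1), warmup_iters=2, timing_iters=5)
print(f"Mean time: {mean_ms:.3f} ms over {len(calls)} calls")
```

With JAX, ``sync`` would typically be ``jax.block_until_ready``, since dispatch is asynchronous and the timer would otherwise stop before the device has finished.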

.. literalinclude:: dense.py
    :language: python
    :start-after: # DENSE_SINGLE_GPU_BENCH_START
    :end-before: # DENSE_SINGLE_GPU_BENCH_END

.. raw:: html

    <div style="background: #f5f5f5; border-left: 3px solid #9ca3af; padding: 4px 12px; font-size: 12px; color: #6b7280; margin-top: -16px;">
    Output:
    </div>

.. container:: program-output

    .. literalinclude:: dense.out
        :language: text
        :start-after: # SINGLE_GPU_OUTPUT_START
        :end-before: # SINGLE_GPU_OUTPUT_END

On a single GB200, that's roughly **1.6× faster** (18.056 ms vs. 11.260 ms) for
the fwd+bwd of one large Dense, and the only code change was passing
``dot_general=te_dot_general_cls()`` into ``nn.Dense``.

The speedup depends on shape: large GEMMs benefit most. Very small GEMMs may
not benefit at all because the cast + scale overhead can dominate.
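
The speedup can be checked directly from the mean times recorded in ``dense.out``. The short script below only does that arithmetic; the dictionary layout and key names are illustrative, and the timings are the captured GB200 numbers, not a fresh measurement.

```python
# Mean step times in milliseconds, copied from dense.out.
timings_ms = {
    "single_gpu": {"bf16": 18.056, "te_mxfp8": 11.260},
    "dp2_tp2": {"bf16": 5.516, "te_mxfp8": 3.712},
}


def speedup(baseline_ms, candidate_ms):
    """How many times faster the candidate is than the baseline."""
    return baseline_ms / candidate_ms


speedups = {
    name: speedup(t["bf16"], t["te_mxfp8"]) for name, t in timings_ms.items()
}

for name, factor in speedups.items():
    print(f"{name}: {factor:.2f}x")
```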

.. warning::

    **Remat / activation checkpointing.** If your training loop uses
    ``jax.checkpoint_policies.checkpoint_dots`` (or any policy that matches
    ``jax.lax.dot_general``), swap it for
    ``transformer_engine.jax.checkpoint_policies.checkpoint_dots_and_te_gemms``.
    Otherwise TE's quantized GEMM primitives won't be checkpointed correctly
    and your performance comparison will not be accurate.


4. Multi-GPU: DP=2 / TP=2 on a single Dense
-------------------------------------------

**Prerequisite:** this section requires four GPUs.

Keeping the same ``FlaxDenseBlock`` from the rest of the document, we run it on
a 2×2 mesh with **data parallelism** on one axis and **tensor parallelism**
(column-parallel: shard the kernel's output dim) on the other.

Two pieces wire this up:

1. A ``jax.sharding.Mesh``, built once outside JIT (here in
   ``build_dp_tp_mesh``).
2. TE's ``MeshResource``, set globally via ``global_shard_guard``, which tells
   TE which mesh axes are DP and TP.

.. literalinclude:: dense.py
    :language: python
    :start-after: # DENSE_MULTI_GPU_MESH_SETUP_START
    :end-before: # DENSE_MULTI_GPU_MESH_SETUP_END

**Sharding plan:**

.. csv-table::
    :header: "Tensor", "Shape", "PartitionSpec"
    :widths: 30, 40, 30

    "Kernel (column-parallel)", "``(hidden, out_features)``", "``P(None, 'tp')``"
    "Input activations", "``(batch, seq, hidden)``", "``P('dp', None, None)``"
    "Gradient on output", "``(batch, seq, out_features)``", "``P('dp', None, 'tp')``"

.. literalinclude:: dense.py
    :language: python
    :start-after: # DENSE_MULTI_GPU_SHARD_SETUP_START
    :end-before: # DENSE_MULTI_GPU_SHARD_SETUP_END
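
To make the sharding plan concrete, the following pure-Python sketch works out the per-device shard shape of each tensor on the 2×2 mesh. It does not call JAX: the helper ``local_shard_shape`` is hypothetical, and it assumes every sharded dimension divides evenly across its mesh axis.

```python
def local_shard_shape(global_shape, partition_spec, mesh_axis_sizes):
    """Per-device shape of a tensor sharded according to ``partition_spec``.

    ``partition_spec`` has one entry per dimension: a mesh axis name, or None
    when that dimension is not sharded.
    """
    if len(global_shape) != len(partition_spec):
        raise ValueError("partition_spec must have one entry per dimension")
    local = []
    for dim, axis in zip(global_shape, partition_spec):
        parts = 1 if axis is None else mesh_axis_sizes[axis]
        if dim % parts != 0:
            raise ValueError(f"dimension {dim} is not divisible by {parts}")
        local.append(dim // parts)
    return tuple(local)


mesh_axis_sizes = {"dp": 2, "tp": 2}
batch, seq, hidden, out_features = 8, 2048, 8192, 32768

# (global shape, partition spec) for each tensor in the sharding plan.
plan = {
    "kernel": ((hidden, out_features), (None, "tp")),
    "x": ((batch, seq, hidden), ("dp", None, None)),
    "dy": ((batch, seq, out_features), ("dp", None, "tp")),
}

shards = {
    name: local_shard_shape(shape, spec, mesh_axis_sizes)
    for name, (shape, spec) in plan.items()
}

for name, shape in shards.items():
    print(f"{name}: {shape}")
```

The kernel is split only along ``tp``, so each of the two DP replicas holds the same kernel shards, while the input is split only along ``dp`` and is replicated across the TP axis.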

.. literalinclude:: dense.py
    :language: python
    :start-after: # DENSE_MULTI_GPU_BENCH_START
    :end-before: # DENSE_MULTI_GPU_BENCH_END

.. raw:: html

    <div style="background: #f5f5f5; border-left: 3px solid #9ca3af; padding: 4px 12px; font-size: 12px; color: #6b7280; margin-top: -16px;">
    Output:
    </div>

.. container:: program-output

    .. literalinclude:: dense.out
        :language: text
        :start-after: # MULTI_GPU_OUTPUT_START
        :end-before: # MULTI_GPU_OUTPUT_END


Next steps
----------

* `Collective GEMM <collective_gemm.html>`_: further speedups by communicating between devices inside the GEMM.
* `← Hub <../te_jax_integration.html>`_