Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
16 commits
Select commit Hold shift + click to select a range
9003f0d
[OMNIML-5029] launcher: bind-mount Megatron-LM via megatron_install_path
hychiang-git Jun 9, 2026
4af99c7
[OMNIML-5030] tests: enable sequence_parallel for MoE+TP in test_moe_…
hychiang-git Jun 9, 2026
e3a6f5c
[OMNIML-4998] quantization: per-expert weight amax on TEGroupedLinear…
hychiang-git Jun 8, 2026
e8d9eea
[OMNIML-4998] quantization: dist-checkpoint round-trip for per-expert…
hychiang-git Jun 8, 2026
1e11b4b
[OMNIML-4998] quantization: gather per-expert amax before save, not d…
hychiang-git Jun 8, 2026
61c68d6
[OMNIML-4998] quantization: AC-2 smoke uses TP=2 EP=2 (known-good MCo…
hychiang-git Jun 9, 2026
6751331
[OMNIML-4998] quantization: mirror test_moe_sharded_state_dict setup …
hychiang-git Jun 9, 2026
8fd0d53
[OMNIML-4998] quantization: per-expert _amax shape on TEGrouped restore
hychiang-git Jun 10, 2026
c13ad0b
[OMNIML-4998] quantization: reshape loaded per-expert _amax at the le…
hychiang-git Jun 11, 2026
82b4979
[OMNIML-4998] quantization: TP-sync per-expert _amax in max_calibrate
hychiang-git Jun 11, 2026
d7ccf0a
[OMNIML-5072] quantization: use unbind(0) to unpack axis-0 per-expert…
hychiang-git Jun 11, 2026
26307c7
[OMNIML-5072] quantization: no-stack autograd.Function for per-expert…
hychiang-git Jun 12, 2026
4c164f9
[OMNIML-5072] quantization: fuse STE backward in _GroupedAxis0FakeQua…
hychiang-git Jun 12, 2026
e4945a8
[OMNIML-5072] revert AC5 no-stack Function — negative result on B300 …
hychiang-git Jun 12, 2026
0bf4838
[OMNIML-5072] Triton per-expert axis-0 fake-quant kernel for TEGroupe…
hychiang-git Jun 12, 2026
1080e68
[OMNIML-5072] Triton fwd: round-half-to-even via libdevice.rint
hychiang-git Jun 12, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 80 additions & 0 deletions modelopt/torch/kernels/quantization/gemm/VALIDATION_TODO.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# OMNIML-5072 Option B — Triton kernel validation record

**Status: VALIDATED 2026-06-12.** Kernel lives at `grouped_axis0_fakequant.py`,
wired into `te_grouped_quantized_linear_fn` via `_GroupedAxis0FakeQuantFn`.

## Validation summary

### 1. Numerical fidelity

Compared `grouped_axis0_fakequant(weights, amax)` against per-expert
`cuda_ext.fake_tensor_quant(w_i, amax_i)` (A's path) at the Ultra production
shape (N=32, [5120, 8192] bfloat16). See
`nmm-sandbox/studies/omniml-5064/microbench/parity_a_vs_btriton.py`.

```
Check max_abs_err Notes
───────────────────────────────── ─────────── ─────────────────────────────
forward output 0.03125 = 1 ULP at this quant scale
(rounding mode: Triton does
round-half-away, cuda_ext does
round-half-to-even; differs only
on exact half-step boundaries)
backward grad (pass_through_bwd) 0.0 bit-exact identical to A
```

### 2. Bench (OMNIML-5064 microbench, B300, aws-cmh)

Btriton beats A on every column at every cell:

```
Cell impl fwd_us bwd_us step_us Notes
───────────────────────────────────────────────────────────────────────────────
Nano EP=1 A 20,221 5,807 27,422
Btriton 2,670 3,591 7,472 fwd 7.6× win
Super EP=1 A 57,736 17,736 84,289
Btriton 9,584 14,923 33,242 fwd 6.0× win
Ultra mock16 A 5,835 547 8,246
Btriton 1,795 549 4,210 fwd 3.25× win; bwd tied
Nano EP=4 A 3,985 1,340 5,693
Btriton 1,208 1,221 2,815 fwd 3.30× win
Super EP=4 A 14,094 4,187 20,576
Btriton 2,110 3,326 7,710 fwd 6.68× win
Ultra EP=4 A 23,178 2,064 32,686
Btriton 6,730 2,059 16,220 fwd 3.44× win
Ultra EP=8 A 11,695 1,070 16,482 2-node
Btriton 3,515 1,075 8,314 fwd 3.33× win
```

### 3. Distributed validation

- ✓ EP=1 single-rank (mock-EP=16 emulation of EP=16 deployment)
- ✓ EP=4 single-node 4-rank (Nano / Super / Ultra)
- ✓ EP=8 multi-node 2-node 8-rank (Ultra)
- `peak_mb` matches A's within 3 MB across cells (same layer instantiation
signal)

### 4. Hardware coverage

Tested on B300 (compute 10.0+, aws-cmh). The kernel uses no SM-specific
intrinsics (no MMA, no async copy, no tensor cores) — pure load/store +
arithmetic in `triton.language`. Should compile and run on any GPU where
`torch.cuda.is_available()` returns True and Triton is installed; the
`IS_AVAILABLE` guard in `__init__.py` skips it otherwise.

## Design notes

The kernel takes N expert weights as a `[N]` int64 tensor of base pointers
(via each tensor's `.data_ptr()`). Each Triton program reads its expert's
pointer, then strides through a block of elements. Grid: `(N, num_blocks_per_expert)`.
This eliminates the `torch.stack` memcopy on the forward path (~2.7 GB at
Ultra scale).

Backward honors modelopt's `pass_through_bwd` flag (`config.py` default
`True`). When set, the backward returns `grad_outputs` unchanged with zero
kernel launches — matching `_fake_quant_backward_function`'s no-save
behavior. When `False`, the clip-aware STE Triton backward kernel runs.

Soft-gated at the call site in `te_grouped_quantized_linear_fn`: falls back
to the stack-then-quant-then-unbind path when the Triton kernels aren't
available or when calibration is active (`q._if_calib`).
4 changes: 4 additions & 0 deletions modelopt/torch/kernels/quantization/gemm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,8 @@
if torch.cuda.get_device_capability() >= (8, 9):
from .fp4_kernel_hopper import *

# OMNIML-5072 Option B — per-expert axis-0 fake-quant via tensor-of-pointers.
# Generic Triton + CUDA; no special hardware. See VALIDATION_TODO.md.
from .grouped_axis0_fakequant import *

IS_AVAILABLE = True
257 changes: 257 additions & 0 deletions modelopt/torch/kernels/quantization/gemm/grouped_axis0_fakequant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,257 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""Fused per-expert axis-0 fake-quant Triton kernels for TEGroupedLinear.

Replaces the stack-then-quantize-then-unbind pattern in modelopt's TEGrouped
plugin (`te_grouped_quantized_linear_fn`) with a single Triton launch that
processes N expert weights in place, with no contiguous-tensor staging.

Design — tensor of pointers
---------------------------

The N expert weights live as separate Parameters (one per expert), so they're
NOT contiguous in HBM. To avoid a `torch.stack` memcopy (the cost AC5
characterized on OMNIML-5064), we feed the kernel a `[N]` int64 tensor of
expert base pointers. Each Triton program reads its expert's pointer first,
then strides through a block of elements at that address.

Grid: (N, num_blocks_per_expert).
Program 0 of axis 0 → expert 0, program 1 → expert 1, etc.

See OMNIML-5072 AC5 (Option B follow-up) for the motivation.

VALIDATION STATUS (2026-06-11): kernel implemented, numerical fidelity NOT
yet validated against modelopt's reference `fake_quant_impl`, and bench
performance NOT yet measured. See VALIDATION_TODO.md in this directory.
"""

from __future__ import annotations

import torch
import triton
import triton.language as tl
from triton.language.extra.cuda import libdevice

__all__ = ["grouped_axis0_fakequant", "grouped_axis0_fakequant_backward"]


def _torch_dtype_to_tl(dtype: torch.dtype):
"""Map a torch dtype to its Triton-language equivalent."""
return {
torch.float32: tl.float32,
torch.bfloat16: tl.bfloat16,
torch.float16: tl.float16,
}[dtype]


@triton.jit
def _grouped_axis0_fakequant_fwd_kernel(
weight_ptrs_buf, # int64 [N] — N expert base pointers (cast from .data_ptr())
output_ptrs_buf, # int64 [N] — N output base pointers
amax_vec_ptr, # [N, 1, 1] (or anything with N as the leading dim)
elements_per_expert,
num_bits,
narrow_range: tl.constexpr,
DTYPE: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
expert_idx = tl.program_id(axis=0)
block_idx = tl.program_id(axis=1)

# Per-expert base pointers (loaded once per program).
w_int = tl.load(weight_ptrs_buf + expert_idx)
out_int = tl.load(output_ptrs_buf + expert_idx)
w_ptr = w_int.to(tl.pointer_type(DTYPE))
out_ptr = out_int.to(tl.pointer_type(DTYPE))

# Per-expert amax → quant scale.
# amax is stored as fp32; convert to working precision.
amax = tl.load(amax_vec_ptr + expert_idx).to(tl.float32)
# qmax = 2^(num_bits-1) - 1 when narrow_range else 2^(num_bits-1)
# For num_bits=8 narrow_range=True (modelopt default): qmax=127
qmax = ((1 << (num_bits - 1)) - 1) if narrow_range else (1 << (num_bits - 1))
qmin = -qmax if narrow_range else -qmax # signed symmetric
scale = amax / qmax

# Block of elements within this expert.
offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offsets < elements_per_expert

x = tl.load(w_ptr + offsets, mask=mask, other=0.0).to(tl.float32)

# Fake-quant: round(clip(x / scale)) * scale.
# Use scale guards to avoid div-by-zero before _amax is calibrated (early
# batches may carry an _amax of 0; matches modelopt's fake_tensor_quant
# behavior of passing through unchanged).
safe_scale = tl.where(scale > 0.0, scale, 1.0)
q = x / safe_scale
q = tl.maximum(tl.minimum(q, qmax), qmin)
# Round-half-to-even (banker's), matching cuda_ext.fake_tensor_quant exactly.
# libdevice.rint is CUDA's __rint* builtin. Imported via the same path that
# modelopt's nvfp4_quant.py uses (triton.language.extra.cuda.libdevice).
q_rounded = libdevice.rint(q)
out = tl.where(scale > 0.0, q_rounded * scale, x)

tl.store(out_ptr + offsets, out.to(DTYPE), mask=mask)


@triton.jit
def _grouped_axis0_fakequant_bwd_kernel(
weight_ptrs_buf, # int64 [N] — same buffer as fwd
grad_out_ptrs_buf, # int64 [N] — upstream grad pointers (per expert)
grad_in_ptrs_buf, # int64 [N] — output: downstream grad pointers
amax_vec_ptr, # [N, ...] — same buffer as fwd
elements_per_expert,
DTYPE: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
"""Clip-aware STE backward.

For each expert i:
grad_in[i] = grad_out[i] if |w[i]| <= amax[i] else 0
matches modelopt's `_fake_tensor_quant_backward` semantics.
"""
expert_idx = tl.program_id(axis=0)
block_idx = tl.program_id(axis=1)

w_ptr = tl.load(weight_ptrs_buf + expert_idx).to(tl.pointer_type(DTYPE))
grad_out_ptr = tl.load(grad_out_ptrs_buf + expert_idx).to(tl.pointer_type(DTYPE))
grad_in_ptr = tl.load(grad_in_ptrs_buf + expert_idx).to(tl.pointer_type(DTYPE))

amax = tl.load(amax_vec_ptr + expert_idx).to(tl.float32)

offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offsets < elements_per_expert

# Stay in DTYPE (bf16/fp16) throughout — eliminates fp32 round-trip seen in
# the Btriton2 baseline that capped bwd bandwidth at ~4.2 TB/s vs cuda_ext's
# ~8 TB/s on B300. amax cast to DTYPE once; comparison done in low precision
# (amax values are O(1)-O(10), well within bf16 range).
w = tl.load(w_ptr + offsets, mask=mask, other=0.0)
g = tl.load(grad_out_ptr + offsets, mask=mask, other=0.0)
amax_dt = amax.to(DTYPE)

# Clip-aware STE: pass through gradient where |w| <= amax, else zero.
pass_through = tl.abs(w) <= amax_dt
grad_in = tl.where(pass_through, g, 0.0)

tl.store(grad_in_ptr + offsets, grad_in, mask=mask)


def _build_ptr_buf(tensors: list[torch.Tensor]) -> torch.Tensor:
"""Pack a list of tensors' .data_ptr() into a single int64 tensor on the same device."""
return torch.tensor(
[t.data_ptr() for t in tensors],
dtype=torch.int64,
device=tensors[0].device,
)


def grouped_axis0_fakequant(
weights: list[torch.Tensor],
amax_vec: torch.Tensor,
num_bits: int = 8,
narrow_range: bool = True,
) -> list[torch.Tensor]:
"""Apply per-expert axis-0 fake-quant in a single Triton launch.

Args:
weights: List of N expert weight tensors. Each must have the same shape
`[out, in]` and same dtype.
amax_vec: Per-expert amax buffer of shape `[N, 1, 1]` (or any shape where
element `i` is expert `i`'s amax). dtype should be float32 for
numerical headroom; the kernel casts to fp32 internally.
num_bits: integer bit-width for the fake-quant.
narrow_range: if True, output range is [-qmax, +qmax]; else [-qmax, +qmax-1].
modelopt's default is True.

Returns:
List of N quantized weight tensors, each the same shape and dtype as
the corresponding input.
"""
assert len(weights) >= 1, "grouped_axis0_fakequant requires at least one expert"
N = len(weights)
shape0 = weights[0].shape
dtype0 = weights[0].dtype
device0 = weights[0].device
elements_per_expert = weights[0].numel()
for w in weights[1:]:
assert w.shape == shape0, "all expert weights must share the same shape"
assert w.dtype == dtype0, "all expert weights must share the same dtype"
assert w.device == device0, "all expert weights must share the same device"

outputs = [torch.empty_like(w) for w in weights]

weight_ptrs = _build_ptr_buf(weights)
output_ptrs = _build_ptr_buf(outputs)

# BLOCK_SIZE=2048 was empirically best in the Btriton2 sweep — larger blocks
# (16384 + num_warps=8) regressed both fwd and bwd, likely from worse warp
# occupancy and load coalescing on B300.
BLOCK_SIZE = 2048
num_blocks_per_expert = triton.cdiv(elements_per_expert, BLOCK_SIZE)
grid = (N, num_blocks_per_expert)

with torch.cuda.device(device0):
_grouped_axis0_fakequant_fwd_kernel[grid](
weight_ptrs,
output_ptrs,
amax_vec,
elements_per_expert,
num_bits,
narrow_range=narrow_range,
DTYPE=_torch_dtype_to_tl(dtype0),
BLOCK_SIZE=BLOCK_SIZE,
)

return outputs


def grouped_axis0_fakequant_backward(
weights: list[torch.Tensor],
grad_outputs: list[torch.Tensor],
amax_vec: torch.Tensor,
) -> list[torch.Tensor]:
"""Apply per-expert clip-aware STE backward in a single Triton launch.

Matches modelopt's `_fake_tensor_quant_backward` semantics — gradient
passes through where `|w[i]| <= amax[i]`, else zero.

Args:
weights: List of N expert weight tensors (the original fwd inputs).
grad_outputs: List of N upstream gradients, one per expert.
amax_vec: Per-expert amax buffer (same shape as in fwd).

Returns:
List of N downstream gradients, one per expert.
"""
N = len(weights)
assert len(grad_outputs) == N
shape0 = weights[0].shape
dtype0 = weights[0].dtype
device0 = weights[0].device
elements_per_expert = weights[0].numel()

grad_inputs = [torch.empty_like(w) for w in weights]

weight_ptrs = _build_ptr_buf(weights)
grad_out_ptrs = _build_ptr_buf(grad_outputs)
grad_in_ptrs = _build_ptr_buf(grad_inputs)

BLOCK_SIZE = 2048
num_blocks_per_expert = triton.cdiv(elements_per_expert, BLOCK_SIZE)
grid = (N, num_blocks_per_expert)

with torch.cuda.device(device0):
_grouped_axis0_fakequant_bwd_kernel[grid](
weight_ptrs,
grad_out_ptrs,
grad_in_ptrs,
amax_vec,
elements_per_expert,
DTYPE=_torch_dtype_to_tl(dtype0),
BLOCK_SIZE=BLOCK_SIZE,
)

return grad_inputs
17 changes: 16 additions & 1 deletion modelopt/torch/quantization/model_calib.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,21 @@ def sync_quantizer_amax_across_tp(
quantizer.sync_amax_across_distributed_group(parallel_state.tensor_parallel_group)

# Step 2: Sync amax across relevant parallelism (such as TP / EP)
def _weight_axes_for_sync(module, base_axes):
# For column-parallel TEGroupedLinear in per-expert mode (axis=0), the
# expert weights are NOT TP-sharded (etp=1 case) — axis=0 indexes experts,
# not output channels. Per-rank reductions still produce slightly
# divergent amax across TP (BF16/FP16 sums are not bit-identical across
# ranks even with the same input), and dist-checkpoint save treats
# _amax as replicated and captures only one rank's view. Without this
# sync, model_ref's per-rank-different amax mismatches the loaded
# model_test on every TP rank except the one whose value was saved.
if hasattr(module, "weight0"):
quantizer = getattr(module, "weight_quantizer", None)
if isinstance(quantizer, TensorQuantizer) and quantizer.axis in (0, (0,)):
return list(base_axes) + [0]
return base_axes
Comment on lines +360 to +373

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Include (0,) in the TP sync allowlist.

_weight_axes_for_sync() detects both 0 and (0,), but it only appends 0. The downstream check is an exact quantizer.axis in axes_for_sync, so tuple-normalized per-expert axes still skip TP amax sync and can diverge across ranks.

Proposed fix
         if hasattr(module, "weight0"):
             quantizer = getattr(module, "weight_quantizer", None)
             if isinstance(quantizer, TensorQuantizer) and quantizer.axis in (0, (0,)):
-                return list(base_axes) + [0]
+                return [*base_axes, 0, (0,)]
         return base_axes
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@modelopt/torch/quantization/model_calib.py` around lines 360 - 373, The
function _weight_axes_for_sync currently checks quantizer.axis in (0, (0,)) but
only appends the integer 0 to base_axes, causing tuple-form axes like (0,) to be
missed by downstream exact-match checks; update _weight_axes_for_sync to
normalize the quantizer.axis (e.g., treat (0,) and 0 equivalently) or append
both representations (0 and (0,)) when weight_quantizer is a TensorQuantizer
with axis 0, so that axes_for_sync contains the same form as quantizer.axis and
TP amax sync occurs correctly for per-expert (weight0) modules.


for name, module in model.named_modules():
if getattr(module, "_parallel_state", None) is None:
continue
Expand All @@ -373,7 +388,7 @@ def sync_quantizer_amax_across_tp(
module.weight_quantizer,
name,
"weight_quantizer",
axes_for_sync=[None, -1],
axes_for_sync=_weight_axes_for_sync(module, [None, -1]),
parallel_state=module.parallel_state,
)

Expand Down
Loading
Loading