[OMNIML-4998] Per-expert weight quantization on TEGroupedLinear #1671
Closed
16 commits
9003f0d [OMNIML-5029] launcher: bind-mount Megatron-LM via megatron_install_path (hychiang-git)
4af99c7 [OMNIML-5030] tests: enable sequence_parallel for MoE+TP in test_moe_… (hychiang-git)
e3a6f5c [OMNIML-4998] quantization: per-expert weight amax on TEGroupedLinear… (hychiang-git)
e8d9eea [OMNIML-4998] quantization: dist-checkpoint round-trip for per-expert… (hychiang-git)
1e11b4b [OMNIML-4998] quantization: gather per-expert amax before save, not d… (hychiang-git)
61c68d6 [OMNIML-4998] quantization: AC-2 smoke uses TP=2 EP=2 (known-good MCo… (hychiang-git)
6751331 [OMNIML-4998] quantization: mirror test_moe_sharded_state_dict setup … (hychiang-git)
8fd0d53 [OMNIML-4998] quantization: per-expert _amax shape on TEGrouped restore (hychiang-git)
c13ad0b [OMNIML-4998] quantization: reshape loaded per-expert _amax at the le… (hychiang-git)
82b4979 [OMNIML-4998] quantization: TP-sync per-expert _amax in max_calibrate (hychiang-git)
d7ccf0a [OMNIML-5072] quantization: use unbind(0) to unpack axis-0 per-expert… (hychiang-git)
26307c7 [OMNIML-5072] quantization: no-stack autograd.Function for per-expert… (hychiang-git)
4c164f9 [OMNIML-5072] quantization: fuse STE backward in _GroupedAxis0FakeQua… (hychiang-git)
e4945a8 [OMNIML-5072] revert AC5 no-stack Function — negative result on B300 … (hychiang-git)
0bf4838 [OMNIML-5072] Triton per-expert axis-0 fake-quant kernel for TEGroupe… (hychiang-git)
1080e68 [OMNIML-5072] Triton fwd: round-half-to-even via libdevice.rint (hychiang-git)
80 changes: 80 additions & 0 deletions
modelopt/torch/kernels/quantization/gemm/VALIDATION_TODO.md
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,80 @@ | ||
| # OMNIML-5072 Option B — Triton kernel validation record | ||
|
|
||
| **Status: VALIDATED 2026-06-12.** Kernel lives at `grouped_axis0_fakequant.py`, | ||
| wired into `te_grouped_quantized_linear_fn` via `_GroupedAxis0FakeQuantFn`. | ||
|
|
||
| ## Validation summary | ||
|
|
||
| ### 1. Numerical fidelity | ||
|
|
||
| Compared `grouped_axis0_fakequant(weights, amax)` against per-expert | ||
| `cuda_ext.fake_tensor_quant(w_i, amax_i)` (A's path) at the Ultra production | ||
| shape (N=32, [5120, 8192] bfloat16). See | ||
| `nmm-sandbox/studies/omniml-5064/microbench/parity_a_vs_btriton.py`. | ||
|
|
||
| ``` | ||
| Check max_abs_err Notes | ||
| ───────────────────────────────── ─────────── ───────────────────────────── | ||
| forward output 0.03125 = 1 ULP at this quant scale | ||
| (rounding mode: Triton does | ||
| round-half-away, cuda_ext does | ||
| round-half-to-even; differs only | ||
| on exact half-step boundaries) | ||
| backward grad (pass_through_bwd) 0.0 bit-exact identical to A | ||
| ``` | ||
|
|
||
| ### 2. Bench (OMNIML-5064 microbench, B300, aws-cmh) | ||
|
|
||
| Btriton beats A on `fwd_us` and `step_us` in every cell; `bwd_us` is faster or tied (within ~0.5%) in every cell: | ||
|
|
||
| ``` | ||
| Cell impl fwd_us bwd_us step_us Notes | ||
| ─────────────────────────────────────────────────────────────────────────────── | ||
| Nano EP=1 A 20,221 5,807 27,422 | ||
| Btriton 2,670 3,591 7,472 fwd 7.6× win | ||
| Super EP=1 A 57,736 17,736 84,289 | ||
| Btriton 9,584 14,923 33,242 fwd 6.0× win | ||
| Ultra mock16 A 5,835 547 8,246 | ||
| Btriton 1,795 549 4,210 fwd 3.25× win; bwd tied | ||
| Nano EP=4 A 3,985 1,340 5,693 | ||
| Btriton 1,208 1,221 2,815 fwd 3.30× win | ||
| Super EP=4 A 14,094 4,187 20,576 | ||
| Btriton 2,110 3,326 7,710 fwd 6.68× win | ||
| Ultra EP=4 A 23,178 2,064 32,686 | ||
| Btriton 6,730 2,059 16,220 fwd 3.44× win | ||
| Ultra EP=8 A 11,695 1,070 16,482 2-node | ||
| Btriton 3,515 1,075 8,314 fwd 3.33× win | ||
| ``` | ||
|
|
||
| ### 3. Distributed validation | ||
|
|
||
| - ✓ EP=1 single-rank (mock-EP=16 emulation of EP=16 deployment) | ||
| - ✓ EP=4 single-node 4-rank (Nano / Super / Ultra) | ||
| - ✓ EP=8 multi-node 2-node 8-rank (Ultra) | ||
| - `peak_mb` matches A's within 3 MB across cells (same layer instantiation | ||
| signal) | ||
|
|
||
| ### 4. Hardware coverage | ||
|
|
||
| Tested on B300 (compute 10.0+, aws-cmh). The kernel uses no SM-specific | ||
| intrinsics (no MMA, no async copy, no tensor cores) — pure load/store + | ||
| arithmetic in `triton.language`. Should compile and run on any GPU where | ||
| `torch.cuda.is_available()` returns True and Triton is installed; the | ||
| `IS_AVAILABLE` guard in `__init__.py` skips it otherwise. | ||
|
|
||
| ## Design notes | ||
|
|
||
| The kernel takes N expert weights as a `[N]` int64 tensor of base pointers | ||
| (via each tensor's `.data_ptr()`). Each Triton program reads its expert's | ||
| pointer, then strides through a block of elements. Grid: `(N, num_blocks_per_expert)`. | ||
| This eliminates the `torch.stack` memcopy on the forward path (~2.7 GB at | ||
| Ultra scale). | ||
|
|
||
| Backward honors modelopt's `pass_through_bwd` flag (`config.py` default | ||
| `True`). When set, the backward returns `grad_outputs` unchanged with zero | ||
| kernel launches — matching `_fake_quant_backward_function`'s no-save | ||
| behavior. When `False`, the clip-aware STE Triton backward kernel runs. | ||
|
|
||
| Soft-gated at the call site in `te_grouped_quantized_linear_fn`: falls back | ||
| to the stack-then-quant-then-unbind path when the Triton kernels aren't | ||
| available or when calibration is active (`q._if_calib`). |
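
The sizing behind the design notes can be checked with a few lines of arithmetic. This is a minimal sketch, not code from the PR: it assumes the Ultra production shape quoted in the record (N=32 experts, each `[5120, 8192]` in bfloat16) and the `BLOCK_SIZE = 2048` used by the wrapper in `grouped_axis0_fakequant.py`. The helper `cdiv` and all variable names are illustrative.

```python
# Sizing sketch for the tensor-of-pointers design (illustrative names only).
# Inputs are taken from the validation record: N=32, [5120, 8192], bfloat16.

N_EXPERTS = 32
OUT_FEATURES, IN_FEATURES = 5120, 8192
BYTES_PER_BF16 = 2
BLOCK_SIZE = 2048  # value used by the wrapper in grouped_axis0_fakequant.py


def cdiv(a: int, b: int) -> int:
    """Ceiling division, the same arithmetic as triton.cdiv."""
    return -(-a // b)


elements_per_expert = OUT_FEATURES * IN_FEATURES

# Bytes a torch.stack of all expert weights would have to copy.
stack_bytes = N_EXPERTS * elements_per_expert * BYTES_PER_BF16

# Launch grid: one axis-0 program per expert, one axis-1 program per block.
grid = (N_EXPERTS, cdiv(elements_per_expert, BLOCK_SIZE))

# The pointer table that replaces the stacked copy: one int64 per expert.
ptr_table_bytes = N_EXPERTS * 8

print(f"stack copy avoided: {stack_bytes / 1e9:.2f} GB")
print(f"grid: {grid}")
print(f"pointer table: {ptr_table_bytes} bytes")
```

Under these assumptions the staged copy is about 2.68 GB, consistent with the "~2.7 GB" figure in the notes, while the pointer table itself is only 256 bytes per list of tensors.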
257 changes: 257 additions & 0 deletions
modelopt/torch/kernels/quantization/gemm/grouped_axis0_fakequant.py
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,257 @@ | ||
| # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| """Fused per-expert axis-0 fake-quant Triton kernels for TEGroupedLinear. | ||
|
|
||
| Replaces the stack-then-quantize-then-unbind pattern in modelopt's TEGrouped | ||
| plugin (`te_grouped_quantized_linear_fn`) with a single Triton launch that | ||
| processes N expert weights in place, with no contiguous-tensor staging. | ||
|
|
||
| Design — tensor of pointers | ||
| --------------------------- | ||
|
|
||
| The N expert weights live as separate Parameters (one per expert), so they're | ||
| NOT contiguous in HBM. To avoid a `torch.stack` memcopy (the cost AC5 | ||
| characterized on OMNIML-5064), we feed the kernel a `[N]` int64 tensor of | ||
| expert base pointers. Each Triton program reads its expert's pointer first, | ||
| then strides through a block of elements at that address. | ||
|
|
||
| Grid: (N, num_blocks_per_expert). | ||
| Program 0 of axis 0 → expert 0, program 1 → expert 1, etc. | ||
|
|
||
| See OMNIML-5072 AC5 (Option B follow-up) for the motivation. | ||
|
|
||
| VALIDATION STATUS: numerical fidelity against per-expert | ||
| `cuda_ext.fake_tensor_quant` and bench performance are recorded in | ||
| VALIDATION_TODO.md in this directory. | ||
| """ | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| import torch | ||
| import triton | ||
| import triton.language as tl | ||
| from triton.language.extra.cuda import libdevice | ||
|
|
||
| __all__ = ["grouped_axis0_fakequant", "grouped_axis0_fakequant_backward"] | ||
|
|
||
|
|
||
| def _torch_dtype_to_tl(dtype: torch.dtype): | ||
| """Map a torch dtype to its Triton-language equivalent.""" | ||
| return { | ||
| torch.float32: tl.float32, | ||
| torch.bfloat16: tl.bfloat16, | ||
| torch.float16: tl.float16, | ||
| }[dtype] | ||
|
|
||
|
|
||
| @triton.jit | ||
| def _grouped_axis0_fakequant_fwd_kernel( | ||
| weight_ptrs_buf, # int64 [N] — N expert base pointers (cast from .data_ptr()) | ||
| output_ptrs_buf, # int64 [N] — N output base pointers | ||
| amax_vec_ptr, # [N, 1, 1] (or anything with N as the leading dim) | ||
| elements_per_expert, | ||
| num_bits, | ||
| narrow_range: tl.constexpr, | ||
| DTYPE: tl.constexpr, | ||
| BLOCK_SIZE: tl.constexpr, | ||
| ): | ||
| expert_idx = tl.program_id(axis=0) | ||
| block_idx = tl.program_id(axis=1) | ||
|
|
||
| # Per-expert base pointers (loaded once per program). | ||
| w_int = tl.load(weight_ptrs_buf + expert_idx) | ||
| out_int = tl.load(output_ptrs_buf + expert_idx) | ||
| w_ptr = w_int.to(tl.pointer_type(DTYPE)) | ||
| out_ptr = out_int.to(tl.pointer_type(DTYPE)) | ||
|
|
||
| # Per-expert amax → quant scale. | ||
| # amax is stored as fp32; convert to working precision. | ||
| amax = tl.load(amax_vec_ptr + expert_idx).to(tl.float32) | ||
| # qmax = 2^(num_bits-1) - 1 when narrow_range else 2^(num_bits-1) | ||
| # For num_bits=8 narrow_range=True (modelopt default): qmax=127 | ||
| qmax = ((1 << (num_bits - 1)) - 1) if narrow_range else (1 << (num_bits - 1)) | ||
| qmin = -qmax # signed symmetric; the clip range is [-qmax, +qmax] for either narrow_range value | ||
| scale = amax / qmax | ||
|
|
||
| # Block of elements within this expert. | ||
| offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) | ||
| mask = offsets < elements_per_expert | ||
|
|
||
| x = tl.load(w_ptr + offsets, mask=mask, other=0.0).to(tl.float32) | ||
|
|
||
| # Fake-quant: round(clip(x / scale)) * scale. | ||
| # Use scale guards to avoid div-by-zero before _amax is calibrated (early | ||
| # batches may carry an _amax of 0; matches modelopt's fake_tensor_quant | ||
| # behavior of passing through unchanged). | ||
| safe_scale = tl.where(scale > 0.0, scale, 1.0) | ||
| q = x / safe_scale | ||
| q = tl.maximum(tl.minimum(q, qmax), qmin) | ||
| # Round-half-to-even (banker's), matching cuda_ext.fake_tensor_quant exactly. | ||
| # libdevice.rint is CUDA's __rint* builtin. Imported via the same path that | ||
| # modelopt's nvfp4_quant.py uses (triton.language.extra.cuda.libdevice). | ||
| q_rounded = libdevice.rint(q) | ||
| out = tl.where(scale > 0.0, q_rounded * scale, x) | ||
|
|
||
| tl.store(out_ptr + offsets, out.to(DTYPE), mask=mask) | ||
|
|
||
|
|
||
| @triton.jit | ||
| def _grouped_axis0_fakequant_bwd_kernel( | ||
| weight_ptrs_buf, # int64 [N] — same buffer as fwd | ||
| grad_out_ptrs_buf, # int64 [N] — upstream grad pointers (per expert) | ||
| grad_in_ptrs_buf, # int64 [N] — output: downstream grad pointers | ||
| amax_vec_ptr, # [N, ...] — same buffer as fwd | ||
| elements_per_expert, | ||
| DTYPE: tl.constexpr, | ||
| BLOCK_SIZE: tl.constexpr, | ||
| ): | ||
| """Clip-aware STE backward. | ||
|
|
||
| For each expert i: | ||
| grad_in[i] = grad_out[i] if |w[i]| <= amax[i] else 0 | ||
| matches modelopt's `_fake_tensor_quant_backward` semantics. | ||
| """ | ||
| expert_idx = tl.program_id(axis=0) | ||
| block_idx = tl.program_id(axis=1) | ||
|
|
||
| w_ptr = tl.load(weight_ptrs_buf + expert_idx).to(tl.pointer_type(DTYPE)) | ||
| grad_out_ptr = tl.load(grad_out_ptrs_buf + expert_idx).to(tl.pointer_type(DTYPE)) | ||
| grad_in_ptr = tl.load(grad_in_ptrs_buf + expert_idx).to(tl.pointer_type(DTYPE)) | ||
|
|
||
| amax = tl.load(amax_vec_ptr + expert_idx).to(tl.float32) | ||
|
|
||
| offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) | ||
| mask = offsets < elements_per_expert | ||
|
|
||
| # Stay in DTYPE (bf16/fp16) throughout — eliminates fp32 round-trip seen in | ||
| # the Btriton2 baseline that capped bwd bandwidth at ~4.2 TB/s vs cuda_ext's | ||
| # ~8 TB/s on B300. amax cast to DTYPE once; comparison done in low precision | ||
| # (amax values are O(1)-O(10), well within bf16 range). | ||
| w = tl.load(w_ptr + offsets, mask=mask, other=0.0) | ||
| g = tl.load(grad_out_ptr + offsets, mask=mask, other=0.0) | ||
| amax_dt = amax.to(DTYPE) | ||
|
|
||
| # Clip-aware STE: pass through gradient where |w| <= amax, else zero. | ||
| pass_through = tl.abs(w) <= amax_dt | ||
| grad_in = tl.where(pass_through, g, 0.0) | ||
|
|
||
| tl.store(grad_in_ptr + offsets, grad_in, mask=mask) | ||
|
|
||
|
|
||
| def _build_ptr_buf(tensors: list[torch.Tensor]) -> torch.Tensor: | ||
| """Pack a list of tensors' .data_ptr() into a single int64 tensor on the same device.""" | ||
| return torch.tensor( | ||
| [t.data_ptr() for t in tensors], | ||
| dtype=torch.int64, | ||
| device=tensors[0].device, | ||
| ) | ||
|
|
||
|
|
||
| def grouped_axis0_fakequant( | ||
| weights: list[torch.Tensor], | ||
| amax_vec: torch.Tensor, | ||
| num_bits: int = 8, | ||
| narrow_range: bool = True, | ||
| ) -> list[torch.Tensor]: | ||
| """Apply per-expert axis-0 fake-quant in a single Triton launch. | ||
|
|
||
| Args: | ||
| weights: List of N expert weight tensors. Each must have the same shape | ||
| `[out, in]` and same dtype. | ||
| amax_vec: Per-expert amax buffer of shape `[N, 1, 1]` (or any shape where | ||
| element `i` is expert `i`'s amax). dtype should be float32 for | ||
| numerical headroom; the kernel casts to fp32 internally. | ||
| num_bits: integer bit-width for the fake-quant. | ||
| narrow_range: if True, qmax = 2^(num_bits-1) - 1; else qmax = 2^(num_bits-1). | ||
| The kernel clips to [-qmax, +qmax] in both cases. modelopt's default is True. | ||
|
|
||
| Returns: | ||
| List of N quantized weight tensors, each the same shape and dtype as | ||
| the corresponding input. | ||
| """ | ||
| assert len(weights) >= 1, "grouped_axis0_fakequant requires at least one expert" | ||
| N = len(weights) | ||
| shape0 = weights[0].shape | ||
| dtype0 = weights[0].dtype | ||
| device0 = weights[0].device | ||
| elements_per_expert = weights[0].numel() | ||
| for w in weights[1:]: | ||
| assert w.shape == shape0, "all expert weights must share the same shape" | ||
| assert w.dtype == dtype0, "all expert weights must share the same dtype" | ||
| assert w.device == device0, "all expert weights must share the same device" | ||
|
|
||
| outputs = [torch.empty_like(w) for w in weights] | ||
|
|
||
| weight_ptrs = _build_ptr_buf(weights) | ||
| output_ptrs = _build_ptr_buf(outputs) | ||
|
|
||
| # BLOCK_SIZE=2048 was empirically best in the Btriton2 sweep — larger blocks | ||
| # (16384 + num_warps=8) regressed both fwd and bwd, likely from worse warp | ||
| # occupancy and load coalescing on B300. | ||
| BLOCK_SIZE = 2048 | ||
| num_blocks_per_expert = triton.cdiv(elements_per_expert, BLOCK_SIZE) | ||
| grid = (N, num_blocks_per_expert) | ||
|
|
||
| with torch.cuda.device(device0): | ||
| _grouped_axis0_fakequant_fwd_kernel[grid]( | ||
| weight_ptrs, | ||
| output_ptrs, | ||
| amax_vec, | ||
| elements_per_expert, | ||
| num_bits, | ||
| narrow_range=narrow_range, | ||
| DTYPE=_torch_dtype_to_tl(dtype0), | ||
| BLOCK_SIZE=BLOCK_SIZE, | ||
| ) | ||
|
|
||
| return outputs | ||
|
|
||
|
|
||
| def grouped_axis0_fakequant_backward( | ||
| weights: list[torch.Tensor], | ||
| grad_outputs: list[torch.Tensor], | ||
| amax_vec: torch.Tensor, | ||
| ) -> list[torch.Tensor]: | ||
| """Apply per-expert clip-aware STE backward in a single Triton launch. | ||
|
|
||
| Matches modelopt's `_fake_tensor_quant_backward` semantics — gradient | ||
| passes through where `|w[i]| <= amax[i]`, else zero. | ||
|
|
||
| Args: | ||
| weights: List of N expert weight tensors (the original fwd inputs). | ||
| grad_outputs: List of N upstream gradients, one per expert. | ||
| amax_vec: Per-expert amax buffer (same shape as in fwd). | ||
|
|
||
| Returns: | ||
| List of N downstream gradients, one per expert. | ||
| """ | ||
| N = len(weights) | ||
| assert len(grad_outputs) == N | ||
| shape0 = weights[0].shape | ||
| dtype0 = weights[0].dtype | ||
| device0 = weights[0].device | ||
| elements_per_expert = weights[0].numel() | ||
|
|
||
| grad_inputs = [torch.empty_like(w) for w in weights] | ||
|
|
||
| weight_ptrs = _build_ptr_buf(weights) | ||
| grad_out_ptrs = _build_ptr_buf(grad_outputs) | ||
| grad_in_ptrs = _build_ptr_buf(grad_inputs) | ||
|
|
||
| BLOCK_SIZE = 2048 | ||
| num_blocks_per_expert = triton.cdiv(elements_per_expert, BLOCK_SIZE) | ||
| grid = (N, num_blocks_per_expert) | ||
|
|
||
| with torch.cuda.device(device0): | ||
| _grouped_axis0_fakequant_bwd_kernel[grid]( | ||
| weight_ptrs, | ||
| grad_out_ptrs, | ||
| grad_in_ptrs, | ||
| amax_vec, | ||
| elements_per_expert, | ||
| DTYPE=_torch_dtype_to_tl(dtype0), | ||
| BLOCK_SIZE=BLOCK_SIZE, | ||
| ) | ||
|
|
||
| return grad_inputs |
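
For readers who want to check the kernel arithmetic without a GPU, the per-element math can be mirrored in plain Python. This is a hedged scalar sketch, not the PR's implementation: `fake_quant_ref`, `ste_backward_ref`, and `round_half_away` are hypothetical names, and the code works on single floats rather than tensors. It follows the forward kernel's steps (scale from `amax / qmax`, clip, round, rescale, pass-through when the scale is not positive) and the clip-aware STE rule from the backward kernel. Python's built-in `round` rounds halves to even, which stands in for `libdevice.rint` here.

```python
import math


def qmax_for(num_bits: int = 8, narrow_range: bool = True) -> int:
    """qmax as computed in the forward kernel."""
    full = 1 << (num_bits - 1)
    return full - 1 if narrow_range else full


def round_half_away(q: float) -> float:
    """Round halves away from zero (shown only for contrast)."""
    return math.copysign(math.floor(abs(q) + 0.5), q)


def fake_quant_ref(x, amax, num_bits=8, narrow_range=True, rounder=round):
    """Scalar mirror of the forward math: round(clip(x / scale)) * scale."""
    qmax = qmax_for(num_bits, narrow_range)
    scale = amax / qmax
    if scale <= 0.0:
        return x  # uncalibrated amax: pass the value through unchanged
    q = max(min(x / scale, qmax), -qmax)
    return rounder(q) * scale


def ste_backward_ref(w, grad_out, amax):
    """Clip-aware STE: keep the gradient where |w| <= amax, else zero."""
    return grad_out if abs(w) <= amax else 0.0


# amax = 127/32 gives scale = 1/32 = 0.03125 exactly, so x = 1/64 sits on a
# half-step boundary (x / scale == 0.5).
amax = 3.96875
x = 0.015625
print(fake_quant_ref(x, amax))                           # half-to-even
print(fake_quant_ref(x, amax, rounder=round_half_away))  # half-away
print(fake_quant_ref(10.0, amax))                        # clipped to +amax
print(ste_backward_ref(1.5, 2.0, 1.0))                   # outside clip range
```

With this choice of `amax`, the two rounding modes differ by exactly one quantization step (0.03125) on the half-step input and agree elsewhere, which is the kind of boundary case the rounding comment in the forward kernel is about.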
Include `(0,)` in the TP sync allowlist.

`_weight_axes_for_sync()` detects both `0` and `(0,)`, but it only appends `0`. The downstream check is an exact `quantizer.axis in axes_for_sync`, so tuple-normalized per-expert axes still skip TP `amax` sync and can diverge across ranks.

Proposed fix:

     if hasattr(module, "weight0"):
         quantizer = getattr(module, "weight_quantizer", None)
         if isinstance(quantizer, TensorQuantizer) and quantizer.axis in (0, (0,)):
    -        return list(base_axes) + [0]
    +        return [*base_axes, 0, (0,)]
     return base_axes
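
The membership behavior this comment relies on can be reproduced in isolation. This is a minimal sketch, assuming nothing about modelopt beyond what the comment states: the two helper functions and the placeholder `base_axes` value are hypothetical, and only the list expressions are taken from the proposed diff.

```python
# Python's `in` on a list compares elements with ==, and (0,) == 0 is False,
# so an allowlist holding only the int 0 does not match a tuple-form axis.

def allowlist_before(base_axes):
    """List expression from the current code (per the review comment)."""
    return list(base_axes) + [0]


def allowlist_after(base_axes):
    """List expression from the proposed fix."""
    return [*base_axes, 0, (0,)]


base_axes = [None, -1]  # hypothetical placeholder, not modelopt's real value

int_axis = 0
tuple_axis = (0,)  # same logical axis after tuple normalization

print(int_axis in allowlist_before(base_axes))    # matches
print(tuple_axis in allowlist_before(base_axes))  # does not match: sync skipped
print(tuple_axis in allowlist_after(base_axes))   # matches after the fix
```

An alternative with the same effect would be to normalize `quantizer.axis` to a single form before the membership test; the diff above instead widens the allowlist, which keeps the downstream check untouched.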