[OMNIML-4998] Per-expert weight quantization on TEGroupedLinear #1671

Draft
hychiang-git wants to merge 14 commits into main from hungyuehc/omniml-4998-umbrella

Conversation

@hychiang-git

@hychiang-git hychiang-git commented Jun 10, 2026

Contributor

What does this PR do?

Type of change: New feature

Per-expert (axis=0) weight quantization on _QuantTEGroupedLinear, plus the cluster infrastructure to verify it. Consolidates three Jira tickets onto one branch:

  • OMNIML-4998 — Relax _QuantMegatronTEGroupedLinear._process_quantizer_amax's per-tensor restriction. Per-expert _amax of shape [num_gemms, 1, 1] is calibrated via iter_weights_for_calibration's stacked weight, broadcast through the existing axis-N machinery in TensorQuantizer.forward, and round-tripped via sharded_state_dict save/restore. Per-tensor (axis=None) path remains bit-for-bit equivalent.
  • OMNIML-5029: megatron_install_path field on SlurmConfig / aws_cmh_slurm_factory. Bind-mounts the packaged Megatron-LM so make_tp_sharded_tensor_for_checkpoint(..., allow_shape_mismatch=True) resolves against nemo:25.11.nemotron_3_nano (without the mount, the dist-checkpoint metadata exchange deadlocked until the 10-minute watchdog timeout).
  • OMNIML-5030 — Test fix: enable sequence_parallel=True in the MoE+TP test_moe_sharded_state_dict config so the TEGrouped path doesn't trip the post-OMNIML-5029 assertion.
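
To make the two granularities above concrete, here is a minimal sketch that computes a per-tensor amax and a per-expert amax over a stacked [num_gemms, out, in] weight. It uses plain Python lists instead of tensors, and the helper names are hypothetical; it illustrates the reduction only and is not modelopt's implementation.

```python
def per_tensor_amax(stacked):
    """Single amax over every expert (the axis=None case)."""
    return max(abs(v) for expert in stacked for row in expert for v in row)


def per_expert_amax(stacked):
    """One amax per expert (the axis=0 case); conceptually shape [num_gemms, 1, 1]."""
    return [max(abs(v) for row in expert for v in row) for expert in stacked]


# Two experts, each with a 2x2 weight matrix (illustrative values).
stacked_weights = [
    [[0.5, -1.0], [0.25, 0.75]],
    [[-4.0, 2.0], [1.0, 3.0]],
]
```

With per-tensor quantization both experts would share the larger expert's range; the per-expert reduction keeps a separate range for each.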

Usage

import modelopt.torch.quantization as mtq

config = {
    "algorithm": "max",
    "quant_cfg": [
        {"quantizer_name": "*", "enable": False},
        {
            "quantizer_name": "*experts.linear_fc*.weight_quantizer",
            "cfg": {"num_bits": 8, "axis": 0, "enable": True},
        },
    ],
}
model = mtq.quantize(model, config, forward)
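
As a rough sketch of how the two rules above are intended to combine, the snippet below assumes glob-style name matching where a later matching rule overrides an earlier one. That matching behavior, the simplified `pattern`/`enable` rule keys, and the module names are all assumptions made for illustration, not a description of modelopt's config resolution.

```python
from fnmatch import fnmatch


def resolve_enabled(quantizer_name, rules):
    """Return the enable flag of the last rule whose pattern matches the name."""
    enabled = None
    for rule in rules:
        if fnmatch(quantizer_name, rule["pattern"]):
            enabled = rule["enable"]
    return enabled


# Simplified stand-in for the quant_cfg above: disable everything, then
# re-enable only the expert weight quantizers.
RULES = [
    {"pattern": "*", "enable": False},
    {"pattern": "*experts.linear_fc*.weight_quantizer", "enable": True},
]

# Hypothetical fully qualified quantizer names.
EXPERT_WQ = "decoder.layers.0.mlp.experts.linear_fc1.weight_quantizer"
ATTN_WQ = "decoder.layers.0.self_attention.linear_qkv.weight_quantizer"
```

Under these assumptions only `EXPERT_WQ` ends up enabled; attention weights and expert input quantizers stay disabled.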

Testing

tests/gpu_megatron/torch/quantization/plugins/test_megatron.py::test_te_grouped_per_expert_sharded_state_dict (new) — per-expert axis=0 calibration + dist-checkpoint save/restore round-trip on tp=2 ep=2 etp=1 num_moe_experts=4. PASSED 4/4 ranks on aws-cmh (SLURM 511920, experiment cicd_1781150929, container nvcr.io/nvidia/nemo:25.11.nemotron_3_nano).

Existing tests continue to pass:

  • test_moe_sharded_state_dict[*-True] with the OMNIML-5030 sequence_parallel fix.
  • Per-tensor (axis=None) paths — bit-for-bit equivalent (no behavior change).

Cluster yaml: services/model-optimizer/test/omniml_4998_per_expert_amax_smoke.yaml on nmm-sandbox hungyuehc/omniml-5029 @ 24860c3.

Before your PR is "Ready for review"

  • Is this change backward compatible?: ✅ — per-tensor (axis=None) path is bit-for-bit equivalent; per-expert is a new opt-in via existing axis-N machinery, not a behavior change for existing users.
  • New PIP dependency / copied code?: N/A.
  • Did you write any new necessary tests?: ✅ — test_te_grouped_per_expert_sharded_state_dict.
  • Did you update Changelog?: ❌ — TBD; will add a changelog entry before marking Ready-for-review.
  • Did you get Claude approval on this PR?: N/A — will run /claude review before marking Ready-for-review.

Additional Information

Related Jira:

  • OMNIML-4998 (primary scope; AC-1 and AC-2 met)
  • OMNIML-5029 (launcher fix)
  • OMNIML-5030 (test fix)
  • OMNIML-5051 (follow-up: per-expert quant + TE-grouped MoE LoRA compose)

Relationship to PR #1550 (jennifchen/te_per_expert, WIP): PR #1550 explores the same goal with a different design — opt-in env-var MODELOPT_TEGROUPED_PER_EXPERT_QUANTIZER=1 to create N separate weight_quantizer_i instances. This PR uses the single-quantizer + axis=0 design instead, which reuses modelopt's existing axis-N machinery without a new env-var or per-gemm state. The two designs serve overlapping but not identical needs:

PR #1550 is currently WIP and CONFLICTING vs main. Open to discussion on whether to merge approaches; this PR is the focused common-case implementation.

Summary by CodeRabbit

  • New Features

    • Added per-expert weight quantization support for distributed checkpointing in MoE models.
    • Integrated Megatron-LM mounting into launcher tools for both Slurm and Docker environments.
  • Improvements

    • Enhanced tensor-parallel and expert-parallel weight quantization synchronization during calibration and checkpointing.
  • Tests

    • Added regression test for per-expert weight quantization across distributed configurations.
    • Expanded launcher integration test coverage for Docker and Slurm environments.

hychiang-git and others added 7 commits June 9, 2026 15:11
Adds a megatron_install_path field on SlurmConfig and a symmetric
parameter on slurm_factory, mirroring the existing modelopt_install_path
plumbing. build_slurm_executor and build_docker_executor now append a
{megatron_src}:{megatron_dst} bind-mount onto container_mounts so the
packaged modules/Megatron-LM/megatron source shadows the container's
pre-installed Megatron-LM at /usr/local/lib/python3.12/dist-packages/megatron.

This unblocks running modelopt-from-source against containers whose
bundled Megatron-LM is too old for newer modelopt features — concretely,
test_moe_sharded_state_dict against nemo:25.11.nemotron_3_nano, where
make_tp_sharded_tensor_for_checkpoint silently drops allow_shape_mismatch
and the resulting per-rank shape divergence deadlocks dist-checkpoint
metadata exchange. Path 3 from the OMNIML-5029 fix-paths analysis.
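
A minimal sketch of the mount plumbing described above, assuming container mounts are kept as a list of "src:dst" strings. The helper name and the pre-existing mount are hypothetical; only the Megatron source and destination paths come from this commit message.

```python
def with_megatron_mount(container_mounts, megatron_src, megatron_dst):
    """Return a new mount list with the Megatron bind-mount appended once."""
    mount = f"{megatron_src}:{megatron_dst}"
    if mount in container_mounts:
        return list(container_mounts)
    return [*container_mounts, mount]


MEGATRON_DST = "/usr/local/lib/python3.12/dist-packages/megatron"

mounts = with_megatron_mount(
    ["/host/modelopt:/opt/modelopt"],  # hypothetical existing mount
    "modules/Megatron-LM/megatron",
    MEGATRON_DST,
)
```

Because the destination is the container's installed package directory, the mounted source shadows the bundled Megatron-LM without rebuilding the image.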

Tests:
  - test_slurm_config.py: assert default megatron_install_path
  - test_slurm_executor.py: assert megatron mount appears in container_mounts
  - test_docker_execution.py: new test_megatron_mount for the Docker path
  Full launcher suite: 63 passed, 2 skipped, 0 failed.
Signed-off-by: Hung-Yueh Chiang <hungyuehc@nvidia.com>
…sharded_state_dict

Megatron-LM's MoELayer.forward asserts sequence_parallel=True whenever the
attn_tp_group has size > 1 (see megatron/core/transformer/moe/moe_layer.py).
test_moe_sharded_state_dict's TEGroupedMLP parameterization uses tp_size=2,
so the assertion fires before the dist-checkpoint logic under test even runs.
The SequentialMLP branch sidesteps it by dropping to tp_size=1.

The hybrid-Mamba model path in _gpt_model_provider already passes
sequence_parallel=(tp_size > 1) for exactly this reason; the plain-GPT
branch did not. This commit propagates the same convention:

  - Add a sequence_parallel: bool = False kwarg to get_mcore_gpt_model
    and forward it into TransformerConfig (replacing the previously
    hardcoded sequence_parallel=False).
  - In _gpt_model_provider's get_mcore_gpt_model call, pass
    sequence_parallel=(tp_size > 1).

After this change the GroupedMLP variants (FP8_DEFAULT_CFG and
NVFP4_DEFAULT_CFG with moe_grouped_gemm=True) reach the actual
sharded_state_dict logic instead of erroring out at MoE forward.
Surfaced by the OMNIML-5029 cluster smoke run (job 496032) once the
launcher bind-mount let the test get past the Megatron-LM import phase.

Signed-off-by: Hung-Yueh Chiang <hungyuehc@nvidia.com>
… (axis=0)

Lift the per-tensor-only restriction on _QuantTEGroupedLinear's weight
quantizer so it can compute one amax per expert. The new path is gated on
axis=0 in the standard QuantizerAttributeConfig — default (axis=None)
keeps the legacy single-amax behavior bit-for-bit.

  * _QuantTEGroupedLinear._is_per_expert_weight_quant() — small helper
    that returns True when the existing weight_quantizer carries axis=0.
  * iter_weights_for_calibration: per-expert mode yields one stacked
    [num_gemms, out, in] tensor instead of N separate slices so the
    quantizer's axis-0 reduction produces _amax of shape [N, 1, 1].
  * te_grouped_quantized_linear_fn: per-expert forward stacks the N
    expert weights, fake-quants once, and indexes back; broadcasting
    handles the per-expert scale via _amax[:, None, None].
  * _QuantMegatronTEGroupedLinear._process_quantizer_amax: accept either
    a scalar (per-tensor, legacy) or a num_gemms-element tensor
    (per-expert). Anything else raises with a clear message.
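
The stack, fake-quant once, index back flow can be sketched as follows. This is a simplified symmetric integer fake-quant on plain Python lists with a hypothetical function name, assuming one amax value per expert; the real path operates on tensors through TensorQuantizer.

```python
def fake_quant_per_expert(stacked, amax, num_bits=8):
    """Quantize-dequantize each expert's weights with that expert's own scale."""
    if len(amax) != len(stacked):
        raise ValueError(f"expected {len(stacked)} amax values, got {len(amax)}")
    qmax = 2 ** (num_bits - 1) - 1  # 127 for 8 bits
    result = []
    for expert, expert_amax in zip(stacked, amax):
        # Guard against an all-zero expert so the division below is safe.
        scale = expert_amax / qmax if expert_amax > 0 else 1.0
        result.append(
            [[max(-qmax, min(qmax, round(v / scale))) * scale for v in row] for row in expert]
        )
    return result
```

Indexing `result[i]` recovers expert i's fake-quantized weight. An expert with a small dynamic range keeps fine resolution because its scale is not inflated by a larger expert.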

Closes the long-standing "support a unique quantizer for each gemm"
TODO in _QuantTEGroupedLinear._setup at the per-expert granularity.
Per-channel-within-expert and per-block per-expert (NVFP4) are
deliberately out of scope and tracked separately.

sharded_state_dict / _get_shard_axis_dict wiring for dist-checkpoint
round-trip of the new per-expert amax is the next commit (AC-2). GPU
smoke tests under tests/gpu_megatron/torch/quantization/plugins/ land
with that commit.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Hung-Yueh Chiang <hungyuehc@nvidia.com>
… amax

Builds on the prior commit (per-expert axis=0 forward + calibration) by
plumbing per-expert weight_quantizer._amax through sharded_state_dict and
load_distributed_checkpoint so each EP rank reconstructs the same local
amax bit-for-bit after a save/restore.

The design avoids encoding EP sharding in the checkpoint metadata (which
make_sharded_tensors_for_checkpoint only knows about TP). Instead each
TEGroupedLinear all-gathers its local [num_gemms] amax across the
expert_model_parallel_group before save, so every rank stores an
identical [num_gemms_global] tensor that the dist-checkpoint framework
can safely replicate across DP/TP. On restore each rank narrows back to
its local slice before the base class's view_as() reshape.

Four pieces on _QuantMegatronTEGroupedLinear:

  * _ep_group() — returns the EP group iff it's initialized AND has >1
    rank; gives a clean None path for EP=1 and ad-hoc unit tests where
    parallel_state isn't wired up.
  * _process_quantizer_amax: per-expert branch all-gathers via
    torch.distributed.all_gather then concatenates; per-tensor branch
    unchanged.
  * _load_from_state_dict: per-expert branch narrows the global tensor to
    [ep_rank * num_gemms : (ep_rank+1) * num_gemms] before passing to
    super(). _extra_state filter preserved.
  * _get_shard_axis_dict: strips the TP-axis-0 marker that the
    column-parallel parent would otherwise add for any non-None
    weight_quantizer.axis — correct for per-channel-along-output, but
    wrong for our [num_gemms_global] amax.
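
The gather-on-save and narrow-on-restore pairing can be simulated without a process group. In this sketch each inner list stands in for one EP rank's local amax; the function names are hypothetical, and the slice bounds follow the [ep_rank * num_gemms : (ep_rank+1) * num_gemms] rule stated above.

```python
def ep_all_gather(local_amax_per_rank):
    """Simulate all_gather + concatenate: every rank receives the same global list."""
    gathered = [a for local in local_amax_per_rank for a in local]
    return [list(gathered) for _ in local_amax_per_rank]


def narrow_to_local(global_amax, ep_rank, num_gemms):
    """Recover one rank's local slice from the saved global amax."""
    return global_amax[ep_rank * num_gemms : (ep_rank + 1) * num_gemms]
```

Since every rank stores an identical global list, the checkpoint can treat it as replicated, and each rank recovers exactly the values it contributed.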

Smoke test: test_te_grouped_per_expert_sharded_state_dict under
tests/gpu_megatron/torch/quantization/plugins/. Builds a TEGroupedMLP
GPT with num_moe_experts=4, configures a TE-grouped-only per-expert
config (axis=0 weight quantizer, everything else disabled), runs through
_test_sharded_state_dict on dist_workers_size_4 with tp_size=1 ep_size=2.
Verifies the round-trip helper's bitwise state_dict comparison passes
with the new amax shape. Re-enabling INT4_BLOCKWISE_WEIGHT_ONLY_CFG on
TEGroupedMLP and the per-channel-within-expert / NVFP4 block scales
paths are deliberately out of scope and remain skipped.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Hung-Yueh Chiang <hungyuehc@nvidia.com>
…uring

The previous version did the EP all-gather inside _process_quantizer_amax,
which is called per-key inside _MegatronParallelLinear.sharded_state_dict's
state_dict iteration. Megatron's dist-checkpoint save runs its own
default-PG ALLGATHER metadata exchanges across the world group from
inside that same iteration -- the two interleave at NCCL and one rank
ends up stuck in our EP gather while the others have already advanced to
the default-PG ALLGATHER, producing a 10-minute watchdog timeout.

Move the gather to a single call at the top of sharded_state_dict, before
super().sharded_state_dict starts traversing state_dict. The gathered
[num_gemms_global] tensor is cached on a temporary attribute that
_process_quantizer_amax consumes; the live weight_quantizer._amax buffer
stays at its [num_gemms_local, 1, 1] shape so forward keeps broadcasting
correctly against the stacked [num_gemms_local, out, in] expert weights.
The cache is cleared in a finally block so it doesn't leak after save.

Net effect: per-rank flow during save is now
  1. EP all-gather (one collective per TEGroupedLinear layer, on EP group)
  2. super().sharded_state_dict iterates state_dict (no collectives)
  3. Megatron's dist-checkpoint save (default-PG collectives)
The three phases no longer interleave, so the NCCL deadlock can't form.
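
A minimal sketch of the gather-once, cache, clear-in-finally pattern described above. The class, attribute names, and the injected gather function are hypothetical stand-ins; the sketch only shows the ordering and that the live amax is left untouched.

```python
class GroupedLinearSketch:
    """Toy module: gathers once up front, then serves the cached value."""

    def __init__(self, local_amax, gather_fn):
        self.local_amax = local_amax      # stays at its local shape
        self._gather_fn = gather_fn       # stand-in for the EP all-gather
        self._cached_global_amax = None   # temporary attribute used during save
        self.events = []

    def _process_quantizer_amax(self):
        # Called during state_dict traversal: no collectives here.
        self.events.append("process")
        return self._cached_global_amax

    def sharded_state_dict(self):
        # Phase 1: the single collective, before any traversal starts.
        self.events.append("gather")
        self._cached_global_amax = self._gather_fn(self.local_amax)
        try:
            # Phase 2: traversal consumes the cache.
            return {"weight_quantizer._amax": self._process_quantizer_amax()}
        finally:
            # The cache must not outlive the save call.
            self._cached_global_amax = None
```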

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Hung-Yueh Chiang <hungyuehc@nvidia.com>
…re layout)

Earlier attempts at TP=1 EP=2 hit a 10-minute NCCL watchdog timeout on a
default-PG ALLGATHER (SeqNum=5, NumelIn=1 NumelOut=4) regardless of
weight_quantizer.axis. Confirmed by an explicit axis=None diagnostic run
that reproduced the same deadlock signature, so the hang is in the MCore
+ TEGroupedMLP + dist-checkpoint stack at TP=1 EP=2 -- not in OMNIML-4998's
per-expert plumbing.

The TP=1 EP=2 layout has no existing modelopt test that exercises
dist-checkpoint with grouped MoE: `test_moe_sharded_state_dict` runs
TP=2 EP=2 for grouped_gemm=True, and `test_te_grouped_vs_sequential_quantize`
runs TP=1 EP=2 but no dist-checkpoint. So we're not regressing anything;
we're choosing a layout the surrounding code actually supports.

Switching the AC-2 smoke to TP=2 EP=2 keeps the EP plumbing exercised
(still 2 local experts per EP rank -> the EP all-gather in
_gather_global_per_expert_amax still actually moves data) while staying
on a Megatron/MCore layout that is known to work for grouped MoE
dist-checkpoint save+restore. Followup: open an MCore-side issue for
the TP=1 EP=2 hang and revisit this test config once it lands.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Hung-Yueh Chiang <hungyuehc@nvidia.com>
…exactly

The TP=2 EP=2 run was still hitting the same SeqNum=5 ALLGATHER hang as
TP=1 EP=2. But test_moe_sharded_state_dict already passes at TP=2 EP=2
with FP8_DEFAULT_CFG / NVFP4_DEFAULT_CFG -- so the test infrastructure
works at this layout. Three things differed between our test and the
working one: the fixture (dist_workers_size_4 vs dist_workers),
hidden_size (32 vs 256), and the quant_cfg.

Mirror the working test's fixture and hidden_size so the only varying
thing in this run is our axis=0 quant_cfg. If it now passes, our config
is fine and the prior hangs were from the size_4 fixture or the tiny
hidden_size 32 (which is degenerate -- head_dim=8 with 4 attn heads).
If it still hangs, the per-expert config itself is at fault and we
narrow further.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Hung-Yueh Chiang <hungyuehc@nvidia.com>
@copy-pr-bot

copy-pr-bot Bot commented Jun 10, 2026


This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@coderabbitai

coderabbitai Bot commented Jun 10, 2026

Contributor

Review Change Stack

Important

Review skipped

Draft detected.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: ca5243a6-8219-4d4a-897a-e49829c6ecdc

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.

📝 Walkthrough

This PR adds per-expert weight quantization support for Transformer Engine GroupedLinear modules within Megatron, enabling proper distributed checkpoint handling across expert-parallel ranks. It adjusts TP amax synchronization axes, implements EP-aware gather/load logic with amax reshaping, and extends launcher infrastructure to mount Megatron-LM in execution environments.

Changes

TE GroupedLinear Per-Expert Quantization and Launcher Infrastructure

| Layer | File(s) | Summary |
|---|---|---|
| TP amax sync axis adjustment for per-expert modules | modelopt/torch/quantization/model_calib.py | Calibration adds _weight_axes_for_sync helper to compute TP amax sync axes dynamically, extending base sync axes when TEGroupedLinear per-expert mode (detected via weight0 attribute) has weight quantizer axis 0. |
| Megatron EP gather and sharded state dict integration | modelopt/torch/quantization/plugins/megatron.py | Megatron plugin overrides sharded_state_dict() to all-gather local per-expert _amax tensors into global [num_gemms_global] shape without mutating the live buffer, and implements _load_from_state_dict() to narrow saved global amax back to per-rank local slices. _get_shard_axis_dict() removes TP shard markers for per-expert amax; _process_quantizer_amax() emits cached gathered global amax for per-expert entries. |
| Transformer Engine amax reshape and per-expert quantization hooks | modelopt/torch/quantization/plugins/transformer_engine.py | TE plugin adds _reshape_loaded_amax_to_buffer_shape pre-hook to restore amax tensor shapes during load. _is_per_expert_weight_quant() detects axis-0 per-expert mode; modelopt_post_restore stacks expert weights in per-expert mode; iter_weights_for_calibration yields stacked weights once per-expert; te_grouped_quantized_linear_fn stacks weights, quantizes once, and scatters slices back in per-expert mode. |
| Test model builder and per-expert regression test | tests/_test_utils/torch/megatron/models.py, tests/gpu_megatron/torch/quantization/plugins/test_megatron.py | Model builder exposes sequence_parallel parameter; test provider enables sequence parallel when tp_size > 1 for non-hybrid paths. New regression test exercises TEGrouped MoE checkpoint round-trip with fixed TP/EP layout and axis-0 weight quantization. |
| Launcher config and executor Megatron-LM mounting | tools/launcher/slurm_config.py, tools/launcher/core.py | Slurm config adds megatron_install_path field and factory parameter. Slurm executor mounts Megatron host-to-container from config path. Docker executor adds megatron_src_path parameter (defaults to modules/Megatron-LM/megatron) and mounts alongside existing mounts. |
| Launcher test coverage for Megatron mounting | tools/launcher/tests/test_slurm_config.py, tools/launcher/tests/test_docker_execution.py, tools/launcher/tests/test_slurm_executor.py | Config test asserts Megatron install path defaults. Docker execution tests add megatron_install_path to mocks and verify custom megatron mount mapping. Slurm executor tests update fixtures with megatron_install_path and verify Megatron mount entries in container mounts. |

Sequence Diagram

sequenceDiagram
  participant Calib as max_calibrate
  participant Module as TEGroupedLinear
  participant Megatron as MegatronPlugin
  participant SaveExec as sharded_state_dict
  participant EPGather as _gather_global_per_expert_amax
  participant Checkpoint as Checkpoint
  participant LoadExec as _load_from_state_dict
  participant TE as TEPlugin

  Calib->>Module: detect per-expert (weight0)
  Calib->>Calib: _weight_axes_for_sync returns axes including 0
  Calib->>Module: sync_quantizer_amax_across_tp with adjusted axes

  SaveExec->>EPGather: all-gather local weight._amax
  EPGather->>Checkpoint: save global [num_gemms_global] tensor
  SaveExec->>Checkpoint: cache gathered amax on module

  LoadExec->>Checkpoint: load saved global [num_gemms_global] amax
  LoadExec->>LoadExec: narrow to local [num_gemms] slice per EP rank
  LoadExec->>TE: delegate to base loader

  TE->>Module: _reshape_loaded_amax_to_buffer_shape hook
  TE->>Module: restore amax to live buffer shape

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

🚥 Pre-merge checks | ✅ 6
✅ Passed checks (6 passed)
| Check name | Status | Explanation |
|---|---|---|
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
| Title check | ✅ Passed | The pull request title '[OMNIML-4998] Per-expert weight quantization on TEGroupedLinear' clearly and concisely describes the main feature being added: per-expert weight quantization support for TEGroupedLinear modules. It directly reflects the primary change across the quantization, test, and infrastructure files. |
| Docstring Coverage | ✅ Passed | No functions found in the changed files to evaluate docstring coverage. Skipping docstring coverage check. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Security Anti-Patterns | ✅ Passed | PR contains no unsafe torch.load(), numpy.load(), hardcoded trust_remote_code, eval/exec on untrusted input, new nosec comments, hardcoded credentials, new PIP dependencies, or unsafe deserializati... |

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


Comment @coderabbitai help to get the list of available commands and usage tips.

@github-actions

github-actions Bot commented Jun 10, 2026

Contributor
PR Preview Action v1.8.1

🚀 View preview at
https://NVIDIA.github.io/Model-Optimizer/pr-preview/pr-1671/

Built to branch gh-pages at 2026-06-12 04:35 UTC.
Preview will be ready when the GitHub Pages deployment is complete.

@codecov

codecov Bot commented Jun 10, 2026


Codecov Report

❌ Patch coverage is 0% with 93 lines in your changes missing coverage. Please review.
✅ Project coverage is 75.56%. Comparing base (2c52e7b) to head (e4945a8).
⚠️ Report is 32 commits behind head on main.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| modelopt/torch/quantization/plugins/megatron.py | 0.00% | 48 Missing ⚠️ |
| ...t/torch/quantization/plugins/transformer_engine.py | 0.00% | 37 Missing ⚠️ |
| modelopt/torch/quantization/model_calib.py | 0.00% | 8 Missing ⚠️ |
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1671      +/-   ##
==========================================
- Coverage   77.48%   75.56%   -1.93%     
==========================================
  Files         489      511      +22     
  Lines       54415    60054    +5639     
==========================================
+ Hits        42165    45379    +3214     
- Misses      12250    14675    +2425     
| Flag | Coverage Δ |
|---|---|
| unit | 54.26% <0.00%> (+0.25%) ⬆️ |

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Harness.
📢 Have feedback on the report? Share it here.


hychiang-git and others added 3 commits June 10, 2026 15:16
The dist-checkpoint LOAD path was building model_test's placeholder _amax
from self.weight0 -- the un-stacked 2D [out, in] view of one expert's
weight. With axis=0 that yields _amax of shape [out, 1] (e.g. [1024, 1] for
fc1 at hidden=256 / tp=2 / ep=2), and _gather_global_per_expert_amax then
fails on `amax.view(num_gemms)` because amax.numel() != num_gemms.

The SAVE path was fine because mtq.quantize routes through
iter_weights_for_calibration, which stacks the N expert weights into a
[num_gemms, out, in] tensor before calling weight_quantizer once; axis=0
reduction then sizes _amax as [num_gemms, 1, 1] correctly.

Mirror that stacking pattern in modelopt_post_restore so the restore path
lands on the same placeholder shape. The per-tensor path stays on the
legacy weight0 placeholder (already correct).
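
The shape mismatch described above follows directly from which tensor the axis-0 reduction sees. A small sketch, assuming a keepdims-style max-reduction that preserves only the chosen axis; the helper is hypothetical, it returns an empty shape for the per-tensor case as a simplification, and the input width of 256 is illustrative.

```python
def amax_shape(weight_shape, axis):
    """Shape of a keepdims max-reduction that preserves only `axis`."""
    if axis is None:
        return ()  # per-tensor: treated as a scalar in this sketch
    axis = axis % len(weight_shape)
    return tuple(d if i == axis else 1 for i, d in enumerate(weight_shape))


# Un-stacked weight0 of one expert: axis 0 is the output-channel dimension.
single_expert = amax_shape((1024, 256), axis=0)

# Stacked weights of 2 local experts: axis 0 is now the expert dimension.
stacked_experts = amax_shape((2, 1024, 256), axis=0)
```

Here `single_expert` has 1024 elements while `stacked_experts` has one element per expert, which is why a later `amax.view(num_gemms)` only succeeds on the stacked placeholder.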

Surfaced by the OMNIML-4998 AC-2 GPU smoke after the OMNIML-5029 launcher
fix unblocked the cluster harness.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Hung-Yueh Chiang <hungyuehc@nvidia.com>
…af quantizer

Root cause: the dist-checkpoint tensor and the _pytorch_state_metadata for
per-expert weight_quantizer._amax record different shapes. The tensor is saved
flat as [num_gemms_global] (the legacy _process_quantizer_amax flatten),
while the metadata records the live buffer shape [num_gemms, 1, 1] (axis=0
calibration with keepdims). On restore, the modelopt-state path registers
_amax as [num_gemms, 1, 1] first; the subsequent load_state_dict then tries
to copy the flat tensor into that buffer and trips the strict size check.

Parent-level approaches (TEGroupedLinear._load_from_state_dict mutating the
state_dict in place, or _process_quantizer_amax emitting a multi-dim save
shape) introduced rank-asymmetric collective ordering — either because the
mutation conditioned on per-rank buffer-registration timing or because the
save-shape change perturbed the ShardedTensor metadata that
determine_global_metadata exchanges via all_gather_object across ranks.

Fix at the leaf instead. _QuantTEGroupedLinear._setup now registers a
load_state_dict pre-hook on weight_quantizer that reshapes the loaded
_amax to the live buffer shape when shapes differ but numel matches. The
hook is:
  - Leaf-level (runs at the TensorQuantizer module PyTorch actually iterates).
  - Rank-local (touches only local state_dict and buffer shape).
  - Side-effect-free (no collectives, no conditional code paths beyond the
    presence-of-buffer check, which is symmetric across ranks once
    set_extra_state has run on all ranks at the end of the first
    load_state_dict pass).

Pairs with the modelopt_post_restore stacking commit (8fd0d53) so the live
buffer shape on a fresh restore is [num_gemms, 1, 1] — which the hook then
matches the loaded flat tensor to.
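
The decision the hook makes can be sketched on shapes alone. This is a simplified stand-in with a hypothetical function name, assuming shapes are given as tuples; the actual hook operates on the tensors in the incoming state_dict.

```python
import math


def reshape_target(loaded_shape, buffer_shape):
    """Return the shape the loaded amax should be viewed as, or None to leave it alone."""
    if buffer_shape is None or loaded_shape == buffer_shape:
        return None  # no live buffer yet, or nothing to fix
    if math.prod(loaded_shape) != math.prod(buffer_shape):
        return None  # genuine size mismatch: let the strict check report it
    return buffer_shape
```

The decision depends only on rank-local shapes, so it does not introduce any cross-rank ordering of its own.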

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Hung-Yueh Chiang <hungyuehc@nvidia.com>
For column-parallel TEGroupedLinear in per-expert mode (axis=0), max_calibrate
was NOT syncing _amax across the TP group. The existing axes_for_sync list
[None, -1] was designed for the classic case where axis=0 means "per-output-
channel" (TP-sharded; syncing would corrupt per-channel amax). For TEGrouped
in per-expert mode, axis=0 means "per-expert" (NOT TP-sharded under etp=1),
so it SHOULD be synced.

Symptom this fixes (caught by AC-2 GPU smoke on aws-cmh):
  - Each TP rank within an EP group ran calibration on bit-identical-ish
    expert weights (BF16/FP16 sums aren't bit-identical across ranks) and
    produced slightly divergent per-expert amax.
  - dist-checkpoint treats _amax as replicated (per `_get_shard_axis_dict`
    override) and saves only one rank's view (effectively tp_rank=0).
  - On restore, every rank loaded tp_rank=0's amax. model_ref (still
    holding its per-rank-different amax) then mismatched model_test on
    every TP rank except tp_rank=0, causing
    `assert torch.allclose(model_ref._amax, model_test._amax)` to raise on
    tp_rank=1 ranks. Failed ranks entered teardown's barrier; passing
    ranks entered the test's terminal barrier; both stuck at op 27 on
    different call sites, watchdog killed the job at 10-min mark.

Fix: detect TEGroupedLinear in per-expert mode (module has `weight0` attr
and `weight_quantizer.axis in (0, (0,))`) and extend axes_for_sync for the
column-parallel weight quantizer to [None, -1, 0]. The TEGroupedLinear
detection is by-attribute (weight0) rather than by-class to avoid a circular
import. Row-parallel already syncs axis=0, no change there.
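
To illustrate why the sync matters, the sketch below simulates two TP ranks whose per-expert amax differ slightly and reconciles them with an element-wise max. The use of a max-reduction is an assumption made for this illustration, and the function name is hypothetical.

```python
def sync_amax_across_tp(per_rank_amax):
    """Give every TP rank the element-wise max of all ranks' per-expert amax."""
    synced = [max(values) for values in zip(*per_rank_amax)]
    return [list(synced) for _ in per_rank_amax]


# Two TP ranks, two local experts each; values diverge in the low digits.
before = [[1.0, 2.0], [1.0001, 1.9999]]
after = sync_amax_across_tp(before)
```

After the sync every rank holds the same values, so whichever rank's copy the checkpoint keeps, all ranks compare equal on restore.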

Verified by `services/model-optimizer/test/omniml_4998_per_expert_amax_smoke.yaml`
on aws-cmh against nvcr.io/nvidia/nemo:25.11.nemotron_3_nano with OMNIML-5029's
megatron_install_path bind-mount: test_te_grouped_per_expert_sharded_state_dict
PASSED on all 4 ranks (tp=2 ep=2 etp=1 num_experts=4).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Hung-Yueh Chiang <hungyuehc@nvidia.com>
@hychiang-git hychiang-git changed the title Hungyuehc/omniml 4998 umbrella [OMNIML-4998] Per-expert weight quantization on TEGroupedLinear Jun 11, 2026
@hychiang-git

Contributor Author

Related to PR #1550 — same scope (per-expert weight quantization on TEGroupedLinear), alternative design:

Happy to discuss merging approaches. Cross-linking so reviewers on either PR have visibility into the other.

@hychiang-git hychiang-git marked this pull request as ready for review June 11, 2026 16:20
@hychiang-git hychiang-git requested a review from a team as a code owner June 11, 2026 16:20
@hychiang-git

Contributor Author

/claude review

@coderabbitai coderabbitai Bot left a comment

Contributor

Warning

CodeRabbit couldn't request changes on this pull request because it doesn't have sufficient GitHub permissions.

Please grant CodeRabbit Pull requests: Read and write permission and re-run the review.

👉 Steps to fix this

Actionable comments posted: 4

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@modelopt/torch/quantization/model_calib.py`:
- Around line 360-373: The function _weight_axes_for_sync currently checks
quantizer.axis in (0, (0,)) but only appends the integer 0 to base_axes, causing
tuple-form axes like (0,) to be missed by downstream exact-match checks; update
_weight_axes_for_sync to normalize the quantizer.axis (e.g., treat (0,) and 0
equivalently) or append both representations (0 and (0,)) when weight_quantizer
is a TensorQuantizer with axis 0, so that axes_for_sync contains the same form
as quantizer.axis and TP amax sync occurs correctly for per-expert (weight0)
modules.

In `@modelopt/torch/quantization/plugins/megatron.py`:
- Around line 787-796: The cached per-expert amax override in
_process_quantizer_amax is too broad because the check `k.endswith("_amax")`
also matches keys like `weight_quantizer._global_amax`; change the condition so
the cached override only applies to the exact per-weight amax key (e.g., match
"weight_quantizer._amax" or the full qualifier used for per-expert weight amax)
— update the if that tests k and "weight_quantizer" to require the precise
suffix or full key "weight_quantizer._amax" before assigning cached.view(...) to
quantizer_state_dict[k], leaving all other _amax keys to use v.view(v.numel()).

In `@tests/gpu_megatron/torch/quantization/plugins/test_megatron.py`:
- Around line 699-712: The second rule in TE_GROUPED_PER_EXPERT_CFG mistakenly
places "enable": True inside the nested "cfg" so the expert weight quantizers
remain disabled; update the rule matching "*experts.linear_fc*.weight_quantizer"
in TE_GROUPED_PER_EXPERT_CFG to set "enable": True at the top level of that rule
(alongside "quantizer_name" and "cfg") rather than inside "cfg", ensuring the
per-expert axis=0 quantization path in _QuantTEGroupedLinear is actually
enabled.

In `@tools/launcher/core.py`:
- Around line 317-319: The new parameter megatron_src_path was inserted before
experiment_title in run_jobs, breaking existing positional calls that pass
modelopt_src_path, experiment_title; restore the positional contract by moving
megatron_src_path after experiment_title in the run_jobs signature (or
alternatively make megatron_src_path keyword-only by adding a * before it), and
keep default None for modelopt_src_path and megatron_src_path; update any
callers if you choose the keyword-only approach.
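
The suffix-matching concern raised for _process_quantizer_amax above can be reproduced with plain strings. The key names come from the review comment; the stricter helper is a hypothetical sketch of one way to narrow the match, not the fix that was applied.

```python
def is_per_expert_weight_amax_key(key):
    """Match only the weight quantizer's own _amax entry, with or without a prefix."""
    return key == "weight_quantizer._amax" or key.endswith(".weight_quantizer._amax")


# The broad check accepts both keys, including the one it should not touch.
broad_matches = [
    k
    for k in ("weight_quantizer._amax", "weight_quantizer._global_amax")
    if k.endswith("_amax") and "weight_quantizer" in k
]
```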

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 0cb44838-5056-48d6-9cf4-dfa316d3d963

📥 Commits

Reviewing files that changed from the base of the PR and between 48767a0 and 82b4979.

📒 Files selected for processing (10)
  • modelopt/torch/quantization/model_calib.py
  • modelopt/torch/quantization/plugins/megatron.py
  • modelopt/torch/quantization/plugins/transformer_engine.py
  • tests/_test_utils/torch/megatron/models.py
  • tests/gpu_megatron/torch/quantization/plugins/test_megatron.py
  • tools/launcher/core.py
  • tools/launcher/slurm_config.py
  • tools/launcher/tests/test_docker_execution.py
  • tools/launcher/tests/test_slurm_config.py
  • tools/launcher/tests/test_slurm_executor.py

Comment on lines +360 to +373
def _weight_axes_for_sync(module, base_axes):
    # For column-parallel TEGroupedLinear in per-expert mode (axis=0), the
    # expert weights are NOT TP-sharded (etp=1 case) — axis=0 indexes experts,
    # not output channels. Per-rank reductions still produce slightly
    # divergent amax across TP (BF16/FP16 sums are not bit-identical across
    # ranks even with the same input), and dist-checkpoint save treats
    # _amax as replicated and captures only one rank's view. Without this
    # sync, model_ref's per-rank-different amax mismatches the loaded
    # model_test on every TP rank except the one whose value was saved.
    if hasattr(module, "weight0"):
        quantizer = getattr(module, "weight_quantizer", None)
        if isinstance(quantizer, TensorQuantizer) and quantizer.axis in (0, (0,)):
            return list(base_axes) + [0]
    return base_axes

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Include (0,) in the TP sync allowlist.

_weight_axes_for_sync() detects both 0 and (0,), but it only appends 0. The downstream check is an exact quantizer.axis in axes_for_sync, so tuple-normalized per-expert axes still skip TP amax sync and can diverge across ranks.

Proposed fix
         if hasattr(module, "weight0"):
             quantizer = getattr(module, "weight_quantizer", None)
             if isinstance(quantizer, TensorQuantizer) and quantizer.axis in (0, (0,)):
-                return list(base_axes) + [0]
+                return [*base_axes, 0, (0,)]
         return base_axes
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@modelopt/torch/quantization/model_calib.py` around lines 360 - 373, The
function _weight_axes_for_sync currently checks quantizer.axis in (0, (0,)) but
only appends the integer 0 to base_axes, causing tuple-form axes like (0,) to be
missed by downstream exact-match checks; update _weight_axes_for_sync to
normalize the quantizer.axis (e.g., treat (0,) and 0 equivalently) or append
both representations (0 and (0,)) when weight_quantizer is a TensorQuantizer
with axis 0, so that axes_for_sync contains the same form as quantizer.axis and
TP amax sync occurs correctly for per-expert (weight0) modules.
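
The mismatch described above comes down to Python's exact-equality membership test: a tuple `(0,)` is not equal to the integer `0`. The following is a minimal sketch of that behavior only, not the modelopt code; `base_axes` and its values are hypothetical placeholders, and `weight_axes_for_sync` here is a simplified stand-in that takes the axis directly instead of a module.

```python
def weight_axes_for_sync(base_axes, quantizer_axis, append_tuple_form):
    """Simplified stand-in: extend the allowlist for a per-expert (axis 0) quantizer."""
    if quantizer_axis in (0, (0,)):
        extra = [0, (0,)] if append_tuple_form else [0]
        return [*base_axes, *extra]
    return list(base_axes)


base_axes = [None, -1]  # hypothetical allowlist, not taken from the PR
axis = (0,)             # tuple-normalized per-expert axis

# Current behavior: only the int form is appended, so the tuple form is not found.
current = weight_axes_for_sync(base_axes, axis, append_tuple_form=False)
synced_before = axis in current

# Proposed behavior: both forms are appended, so the exact-match check succeeds.
proposed = weight_axes_for_sync(base_axes, axis, append_tuple_form=True)
synced_after = axis in proposed

print(synced_before, synced_after)
```

Normalizing the axis to a single form before the membership check would be an equally valid design; appending both forms is simply the smaller change.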

Comment on lines 787 to +796
 def _process_quantizer_amax(self, k, v, quantizer_state_dict):
-    assert v.numel() == 1, "TEGroupedLinear only supports per-tensor quantization"
-    quantizer_state_dict[k] = v.view(-1)
+    # Per-expert weight amax: emit the gathered global tensor cached by
+    # sharded_state_dict's pre-pass. No collective fires here, so this
+    # call is safe to run inside Megatron's save traversal.
+    # Per-tensor (or non-weight) amax: just reshape the local value.
+    cached = getattr(self, "_cached_global_per_expert_amax", None)
+    if cached is not None and "weight_quantizer" in k and k.endswith("_amax"):
+        quantizer_state_dict[k] = cached.view(cached.numel())
+        return
+    quantizer_state_dict[k] = v.view(v.numel())

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Scope the cached override to weight_quantizer._amax only.

k.endswith("_amax") also matches weight_quantizer._global_amax. In per-expert NVFP4/static configs that would write the gathered per-expert vector into the scalar _global_amax slot, breaking checkpoint round-tripping for that quantizer state.

Proposed fix
-            if cached is not None and "weight_quantizer" in k and k.endswith("_amax"):
+            if cached is not None and k == "weight_quantizer._amax":
                 quantizer_state_dict[k] = cached.view(cached.numel())
                 return
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@modelopt/torch/quantization/plugins/megatron.py` around lines 787 - 796, The
cached per-expert amax override in _process_quantizer_amax is too broad because
the check `k.endswith("_amax")` also matches keys like
`weight_quantizer._global_amax`; change the condition so the cached override
only applies to the exact per-weight amax key (e.g., match
"weight_quantizer._amax" or the full qualifier used for per-expert weight amax)
— update the if that tests k and "weight_quantizer" to require the precise
suffix or full key "weight_quantizer._amax" before assigning cached.view(...) to
quantizer_state_dict[k], leaving all other _amax keys to use v.view(v.numel()).
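
To make the over-match concrete, here is a small sketch comparing the two predicates on a handful of key names. The key list is illustrative: `weight_quantizer._amax` and `weight_quantizer._global_amax` come from the comment above, while `input_quantizer._amax` is an assumed example of a non-weight key.

```python
keys = [
    "weight_quantizer._amax",
    "weight_quantizer._global_amax",
    "input_quantizer._amax",  # assumed example of a non-weight quantizer key
]


def broad_match(k):
    # Condition currently in the PR: substring test plus suffix test.
    return "weight_quantizer" in k and k.endswith("_amax")


def exact_match(k):
    # Condition from the proposed fix: exact key comparison.
    return k == "weight_quantizer._amax"


broad_hits = [k for k in keys if broad_match(k)]
exact_hits = [k for k in keys if exact_match(k)]

print("broad:", broad_hits)
print("exact:", exact_hits)
```

If the keys reaching `_process_quantizer_amax` carry a longer prefix than shown here, `k.endswith("weight_quantizer._amax")` would be the analogous tightening; which form applies depends on the caller and is not confirmed by this page.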

Comment on lines +699 to +712
TE_GROUPED_PER_EXPERT_CFG = {
    "algorithm": "max",
    "quant_cfg": [
        # Disable everything (attention, embeddings, non-MoE MLP) so the test
        # isolates the per-expert path on TEGroupedMLP experts.
        {"quantizer_name": "*", "enable": False},
        # Re-enable axis=0 on TEGrouped MoE experts only -- this triggers the
        # per-expert path inside _QuantTEGroupedLinear.
        {
            "quantizer_name": "*experts.linear_fc*.weight_quantizer",
            "cfg": {"num_bits": 8, "axis": 0, "enable": True},
        },
    ],
}

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Re-enable the expert quantizers at the rule level.

This config disables everything first with {"quantizer_name": "*", "enable": False}, but the follow-up rule puts enable inside cfg. The rest of this file uses top-level enable for toggling, so this can leave the target expert weight quantizers disabled and make the new regression test pass without ever hitting the axis-0 path.

Proposed fix
         {
             "quantizer_name": "*experts.linear_fc*.weight_quantizer",
-            "cfg": {"num_bits": 8, "axis": 0, "enable": True},
+            "enable": True,
+            "cfg": {"num_bits": 8, "axis": 0},
         },
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/gpu_megatron/torch/quantization/plugins/test_megatron.py` around lines
699 - 712, The second rule in TE_GROUPED_PER_EXPERT_CFG mistakenly places
"enable": True inside the nested "cfg" so the expert weight quantizers remain
disabled; update the rule matching "*experts.linear_fc*.weight_quantizer" in
TE_GROUPED_PER_EXPERT_CFG to set "enable": True at the top level of that rule
(alongside "quantizer_name" and "cfg") rather than inside "cfg", ensuring the
per-expert axis=0 quantization path in _QuantTEGroupedLinear is actually
enabled.

Comment thread tools/launcher/core.py
Comment on lines 317 to 319
modelopt_src_path=None,
megatron_src_path=None,
experiment_title="cicd",

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Signature insertion breaks existing positional call contract in run_jobs.

Adding megatron_src_path before experiment_title changes positional binding. The run_jobs call still passes ..., modelopt_src_path, experiment_title, so experiment_title is now consumed as megatron_src_path, producing an invalid Megatron mount source and ignoring the intended title.

💡 Proposed fix
 def build_docker_executor(
     hf_local,
     slurm_config,
     experiment_id,
     job_dir,
     task_name,
     packager,
     modelopt_src_path=None,
-    megatron_src_path=None,
     experiment_title="cicd",
+    megatron_src_path=None,
 ):
                 if hf_local is not None:
                     executor = build_docker_executor(
                         hf_local,
                         task.slurm_config,
                         exp._id,
                         job_dir,
                         task_name,
                         packager,
-                        modelopt_src_path,
-                        experiment_title,
+                        modelopt_src_path=modelopt_src_path,
+                        experiment_title=experiment_title,
                     )
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tools/launcher/core.py` around lines 317 - 319, The new parameter
megatron_src_path was inserted before experiment_title in the
build_docker_executor signature, breaking the existing positional call in
run_jobs that passes modelopt_src_path, experiment_title; restore the positional
contract by moving megatron_src_path after experiment_title in the
build_docker_executor signature (or alternatively make megatron_src_path
keyword-only by adding a * before it), and keep default None for
modelopt_src_path and megatron_src_path; update any callers if you choose the
keyword-only approach.
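
The positional-binding problem can be reproduced with two stripped-down signatures. This is a sketch only: the functions below are hypothetical reductions of `build_docker_executor` that keep just the trailing parameters and return the bound values, and the argument strings are made up.

```python
def executor_before(task_name, packager, modelopt_src_path=None, experiment_title="cicd"):
    """Signature shape before the PR (trailing parameters only)."""
    return {"modelopt": modelopt_src_path, "title": experiment_title}


def executor_after(task_name, packager, modelopt_src_path=None,
                   megatron_src_path=None, experiment_title="cicd"):
    """Signature shape after the PR: the new parameter sits before experiment_title."""
    return {"modelopt": modelopt_src_path, "megatron": megatron_src_path,
            "title": experiment_title}


# The same positional call, as run_jobs is described as making it.
args = ("task", "packager", "/src/modelopt", "nightly")
before = executor_before(*args)
after = executor_after(*args)

# Passing the trailing arguments by keyword is unaffected by parameter order.
by_keyword = executor_after("task", "packager",
                            modelopt_src_path="/src/modelopt",
                            experiment_title="nightly")

print(before)
print(after)       # the title lands in the megatron slot
print(by_keyword)
```

Either remedy from the review (reordering the parameters or switching the call to keywords) restores the intended binding; a bare `*` before the optional parameters would additionally turn any future positional misuse into a `TypeError`.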

@hychiang-git hychiang-git marked this pull request as draft June 11, 2026 16:44
… quantized weights in TEGroupedLinear forward

The forward unpack in te_grouped_quantized_linear_fn used
`q_stacked[i] for i in range(num_gemms)` to feed the N quantized expert
weights into TE's grouped linear. Each select records a SelectBackward
autograd node; backward then materializes an intermediate gradient
buffer per select before scattering its slice into q_stacked.grad. At
high N the intermediate-buffer allocations dominate backward latency by
orders of magnitude.

unbind(0) records a single UnbindBackward whose backward is a fused
torch.stack -- one allocation, one kernel. Same numerical math (STE
through the axis-0 fake-quant); only the autograd graph topology
changes.

Surfaced and measured under the OMNIML-5064 per-MoE-layer microbench
comparison study; landing here for the parent OMNIML-5072 unification.

Signed-off-by: Hung-Yueh Chiang <hungyuehc@nvidia.com>
… axis-0 fake-quant on TEGroupedLinear

Replaces the stack-then-quantize-then-unbind path in te_grouped_quantized_linear_fn
with a custom torch.autograd.Function that loops over the N expert weights
inside Function.forward, calling fake_quant_impl per expert against amax_vec[i].

Why: the previous path stacked the N expert weights into one [N, out, in]
tensor each forward. That stack is intra-HBM memcopy (no NCCL involved) and
scales with the per-GPU stacked weight bytes; it dominates per-step latency
at low per-GPU N with large per-expert weights. The autograd.Function
wrapper also collapses what would otherwise be N per-call Python dispatches
into a single Function entry, so the call site avoids both the stack
memcopy and the per-call Python overhead the legacy per-expert-module path
(weight_quantizer_i) paid.

Backward: clipping-aware STE per expert (dw = grad where |w| <= amax_vec[i],
else 0), matching FakeTensorQuantFunction's _fake_tensor_quant_backward
semantics. Single fused autograd node.
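
As a rough illustration of the clipping-aware STE rule stated above (gradient passes through where |w| <= amax_vec[i], otherwise it is zeroed), here is a pure-Python sketch over nested lists. It is not the modelopt implementation, which operates on tensors; the function name and the sample values are hypothetical.

```python
def ste_backward_per_expert(weights, grads, amax_vec):
    """Per-expert clipping-aware STE: keep grad where |w| <= amax_i, else 0."""
    dw = []
    for w_i, g_i, amax_i in zip(weights, grads, amax_vec):
        dw.append([g if abs(w) <= amax_i else 0.0 for w, g in zip(w_i, g_i)])
    return dw


# Two experts with three weights each (made-up numbers).
weights = [[0.5, -2.0, 1.0], [3.0, -0.1, 4.5]]
grads = [[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]
amax_vec = [1.0, 4.0]

dw = ste_backward_per_expert(weights, grads, amax_vec)
print(dw)
```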

Calibration: handled inline as per-expert max accumulation into the
[N, 1, 1] _amax buffer (under no_grad). This bypasses modelopt's calibrator
interface, which is built around a single stacked input. The path is
coupled to max-calibration; _setup asserts the constraint up-front. Per-
expert support for other calibrators (histogram, percentile) is a separate
ticket.

OMNIML-5072 AC5 design; surfaced under the OMNIML-5064 comparison study.

Signed-off-by: Hung-Yueh Chiang <hungyuehc@nvidia.com>
…nt via modelopt's _fake_tensor_quant_backward

OMNIML-5064 microbench surfaced a backward regression in 26307c7: the per-
expert STE backward was written as a plain-Python loop of `torch.where(
w.abs() <= amax, g, g.new_zeros(1))`, which dispatches three unfused eager-
mode kernels per expert plus one fresh tensor allocation per iteration.
Measured backward at Ultra mock-EP=16 was ~5x slower than the previous
unbind-based path.

Fix: delegate the per-expert STE math to modelopt's existing
@torch.jit.script-decorated _fake_tensor_quant_backward. It fuses the
abs/compare/where into a single kernel per expert and matches what
FakeTensorQuantFunction's own backward uses. Same numerical output;
collapses per-expert backward cost back to the expected single-fused-
kernel level.

Behaviour-equivalent change to 26307c7's backward; forward path is
unchanged.

Signed-off-by: Hung-Yueh Chiang <hungyuehc@nvidia.com>
…single-GPU

Reverts:
  4c164f9 [OMNIML-5072] quantization: fuse STE backward in _GroupedAxis0FakeQuant ...
  26307c7 [OMNIML-5072] quantization: no-stack autograd.Function for per-expert ...

Measurement under OMNIML-5064's Ultra mock-EP=16 microbench (production-shape
per-GPU N=32 on a single B300):

  Cell                  fwd_us   bwd_us   step_us
  A (per-expert mods)    5,835      547    8,246    <- A baseline
  B unbind-only          6,380    1,308    9,549    <- d7ccf0a baseline
  B no-stack Function    6,060    6,155   14,073    <- AC5 regression

AC5's forward improvement was real (~5%) and matched A's fwd within tolerance.
But the per-expert backward kernels operate on N scattered HBM regions instead
of one contiguous stacked tensor, and on B300's high HBM bandwidth the
contiguous-tensor kernel access pattern dominates the (small) stack-allocation
cost. Net: AC5's backward is ~4.7x slower than the unbind-only path's
fused-on-stacked-tensor backward, and step time degrades by ~50%.
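
The ratios quoted above can be rechecked directly from the table. The sketch below only reuses the microsecond figures listed there; the dictionary layout and variable names are illustrative.

```python
# Timings in microseconds, copied from the table above.
cells = {
    "A (per-expert mods)": {"fwd": 5835, "bwd": 547, "step": 8246},
    "B unbind-only": {"fwd": 6380, "bwd": 1308, "step": 9549},
    "B no-stack Function": {"fwd": 6060, "bwd": 6155, "step": 14073},
}

base = cells["B unbind-only"]
ac5 = cells["B no-stack Function"]

fwd_gain = (base["fwd"] - ac5["fwd"]) / base["fwd"]   # forward improvement
bwd_slowdown = ac5["bwd"] / base["bwd"]               # backward slowdown factor
step_regression = ac5["step"] / base["step"] - 1      # step-time increase

print(f"fwd gain: {fwd_gain:.1%}")
print(f"bwd slowdown: {bwd_slowdown:.1f}x")
print(f"step regression: {step_regression:.0%}")
```

On these numbers the step-time regression comes out nearer 47% than 50%, which is consistent with the rounded figure in the message.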

The unbind fix (d7ccf0a) stays — that's still a strict improvement over the
indexed-select pattern at high N. AC5's `_GroupedAxis0FakeQuant` is removed as
a failed experiment; the design rationale for revisiting it (a fused per-expert
CUDA/Triton kernel that achieves coalesced access across N scattered weights
in a single launch) lives in OMNIML-5064 ac2-results.md.

Signed-off-by: Hung-Yueh Chiang <hungyuehc@nvidia.com>