Skip to content

[OMNIML-4944] peft: add lora_dtype field to PEFTAttributeConfig#1646

Draft
hychiang-git wants to merge 9 commits into
mainfrom
hungyuehc/omniml-4944
Draft

[OMNIML-4944] peft: add lora_dtype field to PEFTAttributeConfig#1646
hychiang-git wants to merge 9 commits into
mainfrom
hungyuehc/omniml-4944

Conversation

@hychiang-git

Copy link
Copy Markdown
Contributor

Adds an optional dtype string ('bf16' | 'fp16' | 'fp32', with long-form aliases) that lets a LoRA adapter pin its factor dtype independent of the wrapped layer.

Motivation: for low-bit MoE QAD, the base layer's storage dtype can be fake-quantized (e.g. uint4/int4 stand-ins) while the LoRA sidecar must stay BF16. Today _register_adapter_with_device copies the parent's dtype onto LoRA factors, which is wrong in that regime. lora_dtype gives plugins an explicit override knob; None (default) preserves today's inherit-from-parent behavior.

Validator normalizes long forms to canonical short forms and rejects unsupported strings.

What does this PR do?

Type of change: ?

Usage

# Add a code snippet demonstrating how to use this

Testing

Before your PR is "Ready for review"

Make sure you read and follow Contributor guidelines and your commits are signed (git commit -s -S).

Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded trust_remote_code=True, torch.load(..., weights_only=False), pickle, etc.).

  • Is this change backward compatible?: ✅ / ❌ / N/A
  • If you copied code from any other sources or added a new PIP dependency, did you follow guidance in CONTRIBUTING.md: ✅ / ❌ / N/A
  • Did you write any new necessary tests?: ✅ / ❌ / N/A
  • Did you update Changelog?: ✅ / ❌ / N/A
  • Did you get Claude approval on this PR?: ✅ / ❌ / N/A

Additional Information

Adds an optional dtype string ('bf16' | 'fp16' | 'fp32', with long-form
aliases) that lets a LoRA adapter pin its factor dtype independent of
the wrapped layer.

Motivation: for low-bit MoE QAD, the base layer's storage dtype can be
fake-quantized (e.g. uint4/int4 stand-ins) while the LoRA sidecar must
stay BF16. Today _register_adapter_with_device copies the parent's dtype
onto LoRA factors, which is wrong in that regime. lora_dtype gives
plugins an explicit override knob; None (default) preserves today's
inherit-from-parent behavior.

Validator normalizes long forms to canonical short forms and rejects
unsupported strings.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@copy-pr-bot

copy-pr-bot Bot commented Jun 8, 2026

Copy link
Copy Markdown

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@coderabbitai

coderabbitai Bot commented Jun 8, 2026

Copy link
Copy Markdown
Contributor

Important

Review skipped

Draft detected.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 3e13266c-db2a-4b64-97ef-c55d4c24373b

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.

Use the checkbox below for a quick retry:

  • 🔍 Trigger review
✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Commit unit tests in branch hungyuehc/omniml-4944

Comment @coderabbitai help to get the list of available commands and usage tips.

@github-actions

github-actions Bot commented Jun 8, 2026

Copy link
Copy Markdown
Contributor
PR Preview Action v1.8.1

QR code for preview link

🚀 View preview at
https://NVIDIA.github.io/Model-Optimizer/pr-preview/pr-1646/

Built to branch gh-pages at 2026-06-08 18:03 UTC.
Preview will be ready when the GitHub Pages deployment is complete.

@hychiang-git hychiang-git force-pushed the hungyuehc/omniml-4944 branch from f6687c0 to 3d3dd3e Compare June 8, 2026 03:52
hychiang-git and others added 8 commits June 7, 2026 20:54
Adds modelopt/torch/peft/lora/plugins/megatron_moe.py:
- _StackedLoRAFactor: tiny nn.Module carrier holding one stacked
  Parameter of shape [num_experts, ...] for per-expert factors.
- _LoRATEGroupedBase: LoRAModule subclass with update_layer_lora() that
  allocates A:[E,r,in] / B:[E,out,r] honoring attr_config.lora_dtype, and
  a forward() that splits the input by tokens_per_expert and dispatches
  each expert's factors before summing into the base output.
- _LoRATEGroupedColumnParallelLinear / _LoRATEGroupedRowParallelLinear:
  thin registration markers binding the base class to TE's
  ColumnParallel / RowParallel grouped linears (the production path when
  moe_grouped_gemm=True, as in Nemotron-3 Hybrid).

forward() bypasses LoRAModule.forward via super(LoRAModule, self) because
LoRAModule's adapter loop calls lora_a(x) directly -- which would invoke
_StackedLoRAFactor.forward and raise. The stacked layout requires
per-expert dispatch keyed by tokens_per_expert.

Wires the new module into modelopt/torch/peft/lora/plugins/__init__.py
under the existing megatron import_plugin guard.

Plumbing only -- sharded_state_dict, SVDQuant init, and quant+LoRA
combined registration land in follow-up commits. Smoke test follows in
the next commit.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Adds tests/gpu_megatron/torch/peft/plugins/test_megatron_moe_lora.py
exercising the plugin landed in 3e8e868 on a tiny 2-layer MoE GPT
(num_moe_experts=4, moe_grouped_gemm=True, transformer_impl="transformer_engine"
-- the TEGroupedMLP production path).

Four assertions:
- Registration + zero-init identity: at least one _LoRATEGrouped*ParallelLinear
  exists after update_model; per-expert stacked factor shapes match
  [E, rank, in] / [E, out, rank]; zeros on B means lora_output == base_output;
  disable/enable adapters round-trips.
- Random init perturbs output: Kaiming on both A and B changes the model output.
- lora_dtype pin: lora_dtype="bf16" produces bfloat16 LoRA factor tensors,
  independent of the base layer's dtype.
- Gradient flow: with default freeze_base_model=True, LoRA params receive
  non-zero gradients and base params receive none.

Module-level pytest.mark.skipif(not HAVE_TE_GROUPED) lets the file run as a
no-op when the container's Transformer Engine is older than 1.9.0.dev0 --
expected to flag a container bump if all four tests come back SKIPPED.

Test not yet runtime-verified; pending SLURM submission via nmm-sandbox.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The smoke test's _test_gradient_flow was triggering a hard Megatron
assertion at moe_layer.py:616:

    if self.training and self.attn_tp_group.size() > 1 \
       and not self.config.sequence_parallel:
        raise ValueError("During training, performance may degrade if MoE
                          and tensor parallelism are enabled without also
                          enabling sequence parallelism.")

The MoE test model uses moe_grouped_gemm=True and the dist_workers fixture
defaults to tp_size=4 here; the _test_utils.torch.megatron.models factory
hardcodes sequence_parallel=False and does not surface it via config_kwargs,
so the only knob we can flip is self.training.

PyTorch autograd is mode-independent: loss.backward() in eval mode still
populates .grad on every parameter that participated in the forward, so the
test's actual claim -- LoRA factors receive non-zero gradients while base
weights remain frozen by default -- holds without entering train mode.

Smoke test run on aws-cmh (job 475427) confirms 4/4 green:
  test_registration_and_zero_init_identity PASSED
  test_random_init_perturbs_output         PASSED
  test_lora_dtype_pin                      PASSED
  test_gradient_flow                       PASSED  <- this commit

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Adds a `lora_init_method` enum (Literal["kaiming_zeros", "svdquant"]) to
PEFTAttributeConfig and wires the "svdquant" branch into the TE-grouped
plugin's update_layer_lora. When chosen, per-expert factors are seeded
from a rank-r SVD of the quantization residual W_e - quant(W_e) using
the existing model_calib.svd helper:

  lora_a.weight[e] <- vt   ([rank, in_features])
  lora_b.weight[e] <- us   ([out_features, rank])

so that B_e @ A_e == us @ vt approximates the residual at init time.
Falls back to zero-init on both factors (with a warning) when no enabled
weight_quantizer is attached -- mtq.quantize() must run before
mtpeft.update_model() for SVDQuant init to be meaningful.

Plugin also registers the LoRA grouped class against the quantized
TE-grouped base classes (_MegatronTEGroupedColumnParallelLinear /
_MegatronTEGroupedRowParallelLinear) so the quantize -> LoRA flow
finds them on the grouped path. Mirrors peft/lora/plugins/megatron.py:
244-250 for the non-grouped variant.

Import order in peft/lora/plugins/__init__.py is flipped so megatron_moe
runs before megatron. The dynamic-module registry's class resolution
(modelopt/torch/opt/dynamic.py:_get_registered_nn_class) iterates
registrations in insertion order and picks the first class whose
`forward` identity-matches the target's forward. Both the TE-grouped and
non-grouped quant classes inherit `forward` from _QuantFunctionalMixin,
so they'd otherwise tie and the earlier-registered one (non-grouped)
would win -- causing _LoRAMegatronColumnParallelLinear.update_layer_lora
to be invoked on a TE-grouped target, where `self.input_size` doesn't
exist. With megatron_moe first, the grouped registration wins for
grouped targets; non-grouped targets are unaffected because their
issubclass check still excludes TE-grouped classes.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Adds two assertions to tests/gpu_megatron/torch/peft/plugins/
test_megatron_moe_lora.py, plus the per-tensor INT8 quant config they
use:

- test_svdquant_init_recovers_residual: runs mtq.quantize then
  mtpeft.update_model(svdquant), and verifies that for each
  TE-grouped expert e, B_e @ A_e is a non-trivial rank-r approximation
  of the quantization residual W_e - quant(W_e) (reconstruction error
  bounded by residual norm; factors not all zero).
- test_svdquant_init_no_quantizer_falls_back: skips mtq.quantize and
  asserts svdquant correctly degenerates to zero-init on both factors.

SVDQUANT_LORA_CFG restricts LoRA application to *experts*linear_fc*
patterns. Matches the ticket's "adapters placed one per up_proj /
down_proj across all MoE expert layers" scope, and sidesteps a
pre-existing modelopt bug where LoRA on TE plain linears (attention,
in the quantize -> LoRA flow) crashes on a missing `self.input_size`
attribute. Worth a follow-up ticket for the non-grouped TE+quant+LoRA
case, but out of scope here.

INT8_PER_TENSOR_QUANT_CFG mirrors NVFP4_DEFAULT_CONFIG's list-of-dicts
shape from tests/.../test_megatron_peft.py and stays per-tensor to
satisfy the TEGroupedLinear quantization restriction
(quantization/plugins/megatron.py:696).

All 6 smoke assertions pass on aws-cmh (SLURM job 476191, 4:02).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Adds tests/gpu_megatron/torch/peft/plugins/test_megatron_moe_lora.py::
test_quantize_then_lora_svdquant_state_dict_roundtrip. After
mtq.quantize -> mtpeft.update_model(svdquant), snapshots per-expert
LoRA factors, torch.save the state_dict to disk, zeroes the live
factors in place, then torch.load + load_state_dict(strict=True) and
asserts bitwise equality with the snapshot.

This is the Phase 2 sanity scaffold for OMNIML-4944: it validates the
"LoRA-on-quant-checkpoint" plumbing claim end-to-end at small scale,
without needing Megatron's sharded_state_dict (deferred to PR #3+).
Each dist_workers rank saves to its own tempdir to avoid collisions.

All 7 smoke assertions pass on aws-cmh (SLURM job 479919, 3:07).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Megatron dist-checkpoint integration for the per-expert stacked LoRA
factors on _LoRATEGroupedBase. Three additions:

- _factor_tp_axis(factor_name) on each subclass returns the TP-sharded
  axis of the 2-D per-expert slice (or None for replicated):
    column-parallel: lora_a None, lora_b axis 0 (out_features dim)
    row-parallel:    lora_a axis 1 (in_features dim), lora_b None
- sharded_state_dict() splits each stacked [E_local, ...] factor into
  per-global-expert entries keyed weight{global_idx}, with the EP dim
  encoded as a prepended sharded axis and TP via make_*_sharded_*.
  Mirrors TEGroupedLinear._sharded_state_dict_grouped (transformer_engine.py:1981).
- _load_from_state_dict() stacks the per-expert weight{i} entries back
  into a single stacked weight that matches the carrier's state_dict
  shape. Short-circuits when the stacked key is already present (plain
  torch.save/load path) so it doesn't pay the EP lookup cost there.

_ep_rank_size_tp_group() helper resolves (ep_rank, ep_size, tp_group)
preferring self._pg_collection / self._tp_group when available, falling
back to parallel_state.get_expert_model_parallel_* when not (the small
test factory builds linears via a path that doesn't reliably set
_pg_collection on every per-rank instance).

Unblocks real-Nemotron-3 checkpoint testing on dist-checkpoint format
(jennifchen's W4A16 dirs) and prepares the training-loop integration
that follows.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
… factors

Adds test_quantize_then_lora_svdquant_dist_checkpoint_roundtrip mirroring
test_megatron_peft.py::test_mcore_quantize_then_lora_save_restore but for
the TE-grouped MoE LoRA path. Builds a quantized + LoRA model_ref, saves
via save_distributed_checkpoint + save_sharded_modelopt_state, restores
into a fresh model_test via restore_sharded_modelopt_state +
load_distributed_checkpoint, and asserts the post-LoRA forward output
matches model_ref's within tolerance.

Exercises the new _LoRATEGroupedBase.sharded_state_dict + _load_from_state_dict
hooks end-to-end on Megatron's actual dist-checkpoint format -- the same
format jennifchen's Nemotron-3-Nano-30B-A3B W4A16 checkpoints use.

All 8 smoke assertions pass on aws-cmh.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant