Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
48 commits
Select commit Hold shift + click to select a range
1fee97c
add rabbit feedback
Fridah-nv Feb 6, 2026
3f717dd
minor
Fridah-nv Feb 13, 2026
971b168
tested perplexity
sugunav14 Feb 4, 2026
10c16ca
tested, revert later
sugunav14 Feb 9, 2026
364fd78
tested
sugunav14 Feb 10, 2026
5aee517
refactor
sugunav14 Feb 11, 2026
4b1e42f
Track global_amax for weight FP4 MSE sweep; Refactor to NVFP4StaticQa…
realAsma Feb 6, 2026
6a15d0d
address reviewers feedback, delegate scaling factor calculation to NV…
Fridah-nv Feb 6, 2026
7b7146b
tested perplexity
sugunav14 Feb 4, 2026
40c14ef
tested exported checkpoints on 0211
sugunav14 Feb 12, 2026
7a1e006
tested nano v3
sugunav14 Feb 13, 2026
e6df379
added activation MSE logging
sugunav14 Feb 16, 2026
b81fed8
super v3 run
sugunav14 Feb 17, 2026
f3a9524
added activation MSE logging helper
sugunav14 Feb 17, 2026
22e2b95
input amax sync added + tested gptq super sft checkpoint
sugunav14 Feb 19, 2026
10d21ba
checkpoints generated on 0223
sugunav14 Feb 23, 2026
188fa1d
tested perplexity
sugunav14 Feb 4, 2026
599227e
tested, revert later
sugunav14 Feb 9, 2026
60df0d8
tested
sugunav14 Feb 10, 2026
f88ba6e
initial cleanup
sugunav14 Feb 24, 2026
7b24cd3
cleanup
sugunav14 Feb 24, 2026
b17b917
removed stray prints
sugunav14 Feb 24, 2026
8ff8976
fix rebase issues
sugunav14 Mar 6, 2026
5815ce8
minor
sugunav14 Mar 6, 2026
b1f1434
tested e2e on qwen
sugunav14 Mar 6, 2026
df6b182
removed perplexity eval
sugunav14 Mar 6, 2026
75a08fe
update
sugunav14 Mar 6, 2026
9e58a6f
revert later
sugunav14 Mar 16, 2026
16086c7
minor update
sugunav14 Mar 19, 2026
9b47e77
update
sugunav14 Mar 18, 2026
4ec2433
gptq faster
sugunav14 Mar 18, 2026
2b0af3d
added metrics files, remove later
sugunav14 Mar 20, 2026
ee40b48
claude review
sugunav14 Mar 21, 2026
a175178
remove stray files
sugunav14 Mar 21, 2026
a948497
refactor
sugunav14 Mar 22, 2026
7e235b4
claude review + coderabbit review
sugunav14 Mar 23, 2026
d1498be
refactor
sugunav14 Mar 23, 2026
d8b1d93
stray changes removed
sugunav14 Mar 23, 2026
19fc0c2
Address PR comments
sugunav14 Mar 25, 2026
068e8a9
fixed circular import issue
sugunav14 Mar 25, 2026
2930b55
tested e2e on qwen3-8b
sugunav14 Mar 31, 2026
b35bc85
tested e2e on qwen3-8b
sugunav14 Mar 31, 2026
0f621cd
latest run with export
sugunav14 Apr 6, 2026
af59a55
clean up
sugunav14 Apr 6, 2026
d52d614
update
sugunav14 Apr 6, 2026
c705c24
updated e2e test
sugunav14 Apr 6, 2026
6d8f9b3
new perplexity eval
sugunav14 Apr 7, 2026
224f77d
update
sugunav14 Apr 9, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
278 changes: 257 additions & 21 deletions examples/llm_ptq/hf_ptq.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,13 @@
save_expert_token_count_table,
)
from modelopt.torch.export.model_utils import get_language_model_from_vl, is_multimodal_model
from modelopt.torch.export.plugins.vllm_fakequant_hf import export_hf_vllm_fq_checkpoint
from modelopt.torch.quantization.config import _default_disabled_quantizer_cfg, need_calibration
from modelopt.torch.quantization.metrics_backup import (
ActivationMSELogger,
compute_perplexity,
get_wikitext2,
)
from modelopt.torch.quantization.plugins.accelerate import init_quantized_weights
from modelopt.torch.quantization.utils import is_quantized
from modelopt.torch.utils.dataset_utils import (
Expand Down Expand Up @@ -241,6 +247,78 @@ def make_calib_dataloader(
return calib_dataloader, first_text_speech_dataset


def _make_mse_holdout_dataloader(args, tokenizer, device):
    """Create a hold-out dataloader for activation MSE, excluding calibration samples.

    Uses content-based exclusion: reconstructs the exact calibration texts and
    filters them from the hold-out set. This is robust to empty/filtered samples
    that cause pure index-based skipping to under-count and produce overlap.

    Args:
        args: Parsed CLI namespace; reads ``dataset`` (list of dataset names),
            ``calib_size`` (list of per-dataset calibration sample counts),
            ``activation_mse_max_samples`` and ``batch_size``.
        tokenizer: Tokenizer used both to reconstruct the calibration texts and
            to tokenize the hold-out texts (deep-copied before use so padding
            state never leaks into the shared instance).
        device: Optional device the tokenized batch is moved to; left on CPU
            when falsy.

    Returns:
        A ``DataLoader`` over ``{"input_ids": Tensor}`` dicts, unshuffled.
    """
    from modelopt.torch.utils.dataset_utils import SUPPORTED_DATASET_CONFIG, get_dataset_samples

    dataset_names, calib_sizes = args.dataset, args.calib_size
    n_mse = args.activation_mse_max_samples

    # 1. Reconstruct the exact calibration texts (same args as make_calib_dataloader).
    calib_texts: set[str] = set()
    for ds_name, cs in zip(dataset_names, calib_sizes):
        calib_texts.update(get_dataset_samples(ds_name, cs, tokenizer=tokenizer))

    # 2. Per-split skip for efficiency (avoids re-iterating calibration range).
    #    -(-cs // n) is ceil(cs / n) without importing math.
    skip_per_dataset = []
    for ds_name, cs in zip(dataset_names, calib_sizes):
        n_splits = len(
            SUPPORTED_DATASET_CONFIG.get(ds_name, {}).get("config", {}).get("split", [None])
        )
        skip_per_dataset.append(-(-cs // max(n_splits, 1)))
    skip = max(skip_per_dataset)

    # 3. Distribute hold-out samples proportionally across datasets.
    #    The remainder for the last dataset is clamped to >= 1: when the
    #    per-dataset max(1, ...) floors already consume the whole n_mse budget
    #    (many datasets, small n_mse), the unclamped remainder could be zero or
    #    negative, and a negative size would make the [:hs] slice below truncate
    #    from the wrong end and silently drop samples for the last dataset.
    total_calib = sum(calib_sizes)
    holdout_sizes = [max(1, int(n_mse * cs / total_calib)) for cs in calib_sizes]
    holdout_sizes[-1] = max(1, n_mse - sum(holdout_sizes[:-1]))

    # 4. Collect hold-out texts, requesting extras to replace any filtered overlaps.
    all_holdout_texts: list[str] = []
    for ds_name, hs in zip(dataset_names, holdout_sizes):
        texts = get_dataset_samples(
            ds_name,
            hs + len(calib_texts),
            tokenizer=tokenizer,
            skip_samples=skip,
        )
        filtered = [t for t in texts if t not in calib_texts][:hs]
        all_holdout_texts.extend(filtered)

    # 5. Tokenize and build dataloader (mirrors get_dataset_dataloader logic).
    #    Deep-copy so any padding-side/pad-token mutation by __call__ cannot
    #    affect the caller's tokenizer.
    tok = copy.deepcopy(tokenizer)
    batch_encoded = tok(
        all_holdout_texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,
    )
    if device:
        batch_encoded = batch_encoded.to(device)

    class _InputIdsDataset(torch.utils.data.Dataset):
        # Minimal dataset wrapping a pre-tokenized [num_samples, seq_len] tensor.
        def __init__(self, input_ids):
            self.input_ids = input_ids

        def __getitem__(self, idx):
            return {"input_ids": self.input_ids[idx]}

        def __len__(self):
            return len(self.input_ids)

    return DataLoader(
        _InputIdsDataset(batch_encoded["input_ids"]),
        batch_size=args.batch_size,
        shuffle=False,
    )


def auto_quantize(
args: argparse.Namespace,
language_model: torch.nn.Module,
Expand Down Expand Up @@ -686,11 +764,17 @@ def export_quantized(
if mtp_layer_prefixes:
full_model._mtp_layer_prefixes = mtp_layer_prefixes

export_hf_checkpoint(
full_model,
export_dir=export_path,
extra_state_dict=mtp_state_dict,
)
if args.vllm_fakequant_export:
export_hf_vllm_fq_checkpoint(
full_model,
export_dir=export_path,
)
else:
export_hf_checkpoint(
full_model,
export_dir=export_path,
extra_state_dict=mtp_state_dict,
)

# Restore default padding and export the tokenizer as well.
if tokenizer is not None:
Expand Down Expand Up @@ -747,7 +831,7 @@ def pre_quantize(
allow_fallback=False,
)
else:
generated_ids_before_ptq = full_model.generate(preview_input_ids, max_new_tokens=100)
generated_ids_before_ptq = full_model.generate(preview_input_ids, max_new_tokens=2)

return preview_input_ids, generated_ids_before_ptq

Expand Down Expand Up @@ -786,7 +870,7 @@ def post_quantize(
pass
elif model_type != "llama4" and not is_nemotron_vl_model:
# Our fake quantizer may not be fully compatible with torch.compile.
generated_ids_after_ptq = full_model.generate(preview_input_ids, max_new_tokens=100)
generated_ids_after_ptq = full_model.generate(preview_input_ids, max_new_tokens=2)
elif is_nemotron_vl_model and tokenizer is not None:
generated_ids_after_ptq = run_nemotron_vl_preview(
full_model,
Expand Down Expand Up @@ -910,6 +994,9 @@ def quantize_main(
args, full_model, model_type, tokenizer, calib_dataloader, is_nemotron_vl_model
)

mse_logger = None
mse_data = None

if args.auto_quantize_bits:
assert len(args.qformat.split(",")) > 1, (
"Auto quantization needs multiple quantization format."
Expand Down Expand Up @@ -937,10 +1024,17 @@ def quantize_main(
"Plain quantization supports only one quantization format."
)

assert args.qformat in QUANT_CFG_CHOICES, (
f"Unsupported quantization format: {args.qformat}, choices are: {list(QUANT_CFG_CHOICES.keys())}"
)
quant_cfg = QUANT_CFG_CHOICES[args.qformat]
if args.qformat in QUANT_CFG_CHOICES:
quant_cfg = QUANT_CFG_CHOICES[args.qformat]
else:
# Fallback: resolve dynamically registered configs from the mtq namespace
# (e.g., PSX LUTS configs registered by modelopt-internal plugins).
quant_cfg = getattr(mtq, args.qformat, None)
assert quant_cfg is not None, (
f"Unsupported quantization format: {args.qformat}, "
f"not found in built-in choices {list(QUANT_CFG_CHOICES.keys())} "
f"or in the mtq namespace (check that the required plugin is installed)."
)

quant_cfg = build_quant_cfg(
args.qformat,
Expand Down Expand Up @@ -976,7 +1070,23 @@ def quantize_main(
quant_cfg = copy.deepcopy(quant_cfg)
_set_kv_cache_constant_amax(quant_cfg["quant_cfg"])

if args.qformat in QUANT_CFG_CHOICES:
# Collect original (unquantized) activations before quantization modifies the model
if args.measure_activation_mse:
mse_logger = ActivationMSELogger(
max_samples=args.activation_mse_max_samples,
save_dir=args.activation_mse_save_dir,
)
mse_data = ActivationMSELogger.resolve_data(
input_path=args.activation_mse_input_path,
calib_dataloader=calib_dataloader,
tokenizer=tokenizer,
max_samples=args.activation_mse_max_samples,
max_length=args.calib_seq,
make_holdout_fn=lambda: _make_mse_holdout_dataloader(args, tokenizer, device),
)
mse_logger.collect(language_model, mse_data, phase="original")

if args.qformat in QUANT_CFG_CHOICES or hasattr(mtq, args.qformat):
mono_quantize(
args,
quant_cfg,
Expand All @@ -1002,15 +1112,67 @@ def quantize_main(
is_nemotron_vl_model,
first_text_speech_dataset,
)
export_quantized(
args,
full_model,
language_model,
model_type,
tokenizer,
default_padding_side,
default_pad_token,
)

if mse_logger is not None:
mse_logger.finish(language_model, mse_data)
del mse_logger, mse_data
torch.cuda.empty_cache()

if args.eval_perplexity and tokenizer is not None:
if args.fold_weights:
print("Folding weights before perplexity evaluation...")
mtq.fold_weight(language_model)
if args.eval_perplexity_input_path:
print(f"Loading perplexity eval data from {args.eval_perplexity_input_path}")
eval_data = torch.load(
args.eval_perplexity_input_path, map_location="cpu", weights_only=True
)
# Unbatch to [1, seq_len] per element for compute_perplexity batching
unbatched = []
for t in eval_data:
if not isinstance(t, torch.Tensor):
continue
if t.dim() == 1:
unbatched.append(t.unsqueeze(0))
elif t.shape[0] > 1:
unbatched.extend(t.unbind(0))
else:
unbatched.append(t)
eval_data = [t.unsqueeze(0) if t.dim() == 1 else t for t in unbatched]
print(f"Loaded {len(eval_data)} sequences from {args.eval_perplexity_input_path}")
label = args.eval_perplexity_input_path
else:
eval_data = get_wikitext2(tokenizer, args.eval_perplexity_seq_len)
label = "Wikitext-2"
ppl = compute_perplexity(full_model, eval_data)
print(f"{label} perplexity: {ppl:.2f}")

# Plugin-registered configs (e.g. PSX LUTS from modelopt-internal) are not exportable
# via the standard TRT-LLM / HF export paths. Fall back to save_pretrained().
if args.qformat not in QUANT_CFG_CHOICES and hasattr(mtq, args.qformat):
export_path = args.export_path
if args.vllm_fakequant_export:
print(f"Exporting vLLM fakequant checkpoint (bf16 weights + amax) to: {export_path}")
export_hf_vllm_fq_checkpoint(full_model, export_dir=export_path)
else:
print(
f"qformat '{args.qformat}' is a plugin-registered config and is not exportable "
f"via the standard export pipeline. Saving with save_pretrained() instead."
)
full_model.save_pretrained(export_path)
if tokenizer is not None:
tokenizer.save_pretrained(export_path)
Comment on lines +1150 to +1164
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

This plugin-config export bypass skips the normal artifact cleanup.

By bypassing export_quantized(), this branch never restores the tokenizer's original padding settings and never copies custom model files/configs. That makes the saved artifact materially different from the standard export path, especially for trust_remote_code models.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/llm_ptq/hf_ptq.py` around lines 1088 - 1102, The plugin-config
export branch that uses full_model.save_pretrained() /
export_hf_vllm_fq_checkpoint (when args.qformat not in QUANT_CFG_CHOICES and
hasattr(mtq, args.qformat)) skips the post-export cleanup done by
export_quantized(): restore the tokenizer's original padding/token settings and
copy any custom model files/configs required for trust_remote_code models. Fix
by invoking the same post-export cleanup steps used by export_quantized() after
saving (or by calling a shared helper): restore the tokenizer's original
padding/pad_token/padding_side state and then copy over any custom/model config
files into export_path so the artifact matches the standard export path; ensure
this runs for both the export_hf_vllm_fq_checkpoint and
full_model.save_pretrained branches before tokenizer.save_pretrained.

print(f"Quantized model saved to: {export_path}")
else:
export_quantized(
args,
full_model,
language_model,
model_type,
tokenizer,
default_padding_side,
default_pad_token,
)


def parse_args() -> argparse.Namespace:
Expand Down Expand Up @@ -1064,6 +1226,16 @@ def parse_args() -> argparse.Namespace:
default=512,
)
parser.add_argument("--export_path", default="exported_model")
parser.add_argument(
"--vllm_fakequant_export",
action="store_true",
default=False,
help=(
"Export bf16 weights and amax values separately for vLLM fakequant serving. "
"Produces a standard HF checkpoint with GPTQ-adjusted weights plus a "
"quant_amax.pth file that can be loaded via AMAX_FILE_PATH in vllm_serve_fakequant.py."
),
)
parser.add_argument(
"--dataset",
help=(
Expand Down Expand Up @@ -1219,6 +1391,70 @@ def parse_args() -> argparse.Namespace:
),
)

parser.add_argument(
"--eval_perplexity",
action=argparse.BooleanOptionalAction,
default=False,
help="Evaluate Wikitext-2 perplexity after quantization (before export).",
)
parser.add_argument(
"--eval_perplexity_seq_len",
type=int,
default=2048,
help="Sequence length for perplexity evaluation (default: 2048).",
)
parser.add_argument(
"--eval_perplexity_input_path",
type=str,
default=None,
help=(
"Path to a .pt file containing pre-tokenized evaluation data "
"(List[Tensor], each [1, seq_len]) for perplexity evaluation. "
"When set, this data is used instead of WikiText-2. "
"Compatible with the .pt files produced by create_holdout_mse_inputs.py "
"or --activation_mse_input_path."
),
)
parser.add_argument(
"--measure_activation_mse",
action=argparse.BooleanOptionalAction,
default=False,
help="Measure per-layer activation MSE (original vs quantized) after quantization.",
)
parser.add_argument(
"--activation_mse_max_samples",
type=int,
default=16,
help="Max calibration samples for activation MSE (default: 16).",
)
parser.add_argument(
"--activation_mse_save_dir",
type=str,
default=None,
help="Directory to save activation MSE results. If not set, results are only printed.",
)
parser.add_argument(
"--activation_mse_input_path",
type=str,
default=None,
help=(
"Path to frozen MSE input data. Supports two formats:\n"
" .json — raw text (cross-model reuse): if file exists, loads and re-tokenizes "
"with the current model's tokenizer; if not, decodes calibration data to text and saves.\n"
" .pt — tokenized tensors (same-tokenizer reuse): if file exists, loads directly; "
"if not, materializes from calibration data and saves."
),
)
parser.add_argument(
"--fold_weights",
action=argparse.BooleanOptionalAction,
default=False,
help=(
"Fold quantized weights before collecting activation MSE. "
"Speeds up the quantized forward pass by replacing weights in-place "
"and disabling fake-quant, but permanently mutates the weights."
),
)
args = parser.parse_args()
if args.moe_calib_experts_ratio is not None and not (0.0 < args.moe_calib_experts_ratio <= 1.0):
parser.error("--moe_calib_experts_ratio must be in the range (0.0, 1.0].")
Expand Down
10 changes: 9 additions & 1 deletion examples/vllm_serve/fakequant_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ def disable_compilation(model):
"quant_cfg": os.environ.get("QUANT_CFG", None),
"kv_quant_cfg": os.environ.get("KV_QUANT_CFG", None),
"amax_file_path": os.environ.get("AMAX_FILE_PATH", None),
"skip_fold_weight": os.environ.get("SKIP_FOLD_WEIGHT", "0") == "1",
}


Expand Down Expand Up @@ -329,7 +330,14 @@ def calibrate_loop(model: Any = None) -> None:
if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
mtq.print_quant_summary(model)

mtq.fold_weight(model)
if quant_config["skip_fold_weight"]:
print("Skipping fold_weight (weights already quantized, e.g. from GPTQ export)")
for name, module in model.named_modules():
if name.endswith("weight_quantizer"):
module.disable()
else:
mtq.fold_weight(model)

for name, module in model.named_modules():
if name.endswith("weight_quantizer"):
assert not module.is_enabled, f"quantizer {name} is still enabled"
Expand Down
1 change: 1 addition & 0 deletions examples/vllm_serve/vllm_serve_fakequant.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@
"QUANT_CFG",
"AMAX_FILE_PATH",
"KV_QUANT_CFG",
"SKIP_FOLD_WEIGHT",
}

RayDistributedExecutor.ADDITIONAL_ENV_VARS.update(additional_env_vars)
Expand Down
Loading