From 1fee97c3875c8f560996aa807ec50add61422c6f Mon Sep 17 00:00:00 2001 From: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com> Date: Fri, 6 Feb 2026 21:02:12 +0000 Subject: [PATCH 01/48] add rabbit feedback Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com> --- modelopt/torch/quantization/model_calib.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index ed57ea3fc7..9f72fb9c64 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -563,10 +563,17 @@ def forward(self, input, *args, **kwargs): for name, module in name_to_module.items(): if is_quantized_linear(module) and module.weight_quantizer.is_enabled: with enable_weight_access_and_writeback(module, model, name_to_module): +<<<<<<< HEAD module.hessian_helper = LocalHessianHelper(module, name) module.hessian_helper.setup() all_patched_modules.append((name, module)) if module.hessian_helper.is_enabled: +======= + module.local_hessian = LocalHessianHelper(module, name) + module.local_hessian.setup() + all_patched_modules.append((name, module)) + if module.local_hessian.is_enabled: +>>>>>>> e391ea1a (add rabbit feedback) weight_quantizers_info.append((name, module)) # Cache activations by running forward loop @@ -689,7 +696,11 @@ def quant_func(x, amax, quantizer=weight_quantizer): # Cleanup and free memory LocalHessianHelper.cache_mode = False for name, module in all_patched_modules: +<<<<<<< HEAD module.hessian_helper.cleanup() +======= + module.local_hessian.cleanup() +>>>>>>> e391ea1a (add rabbit feedback) print_rank_0("local_hessian: Calibration complete.") From 3f717dd463ddfec487314a80e3bee112a9d9687c Mon Sep 17 00:00:00 2001 From: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com> Date: Thu, 12 Feb 2026 23:12:30 -0800 Subject: [PATCH 02/48] minor Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com> --- 
modelopt/torch/quantization/model_calib.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 9f72fb9c64..ed57ea3fc7 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -563,17 +563,10 @@ def forward(self, input, *args, **kwargs): for name, module in name_to_module.items(): if is_quantized_linear(module) and module.weight_quantizer.is_enabled: with enable_weight_access_and_writeback(module, model, name_to_module): -<<<<<<< HEAD module.hessian_helper = LocalHessianHelper(module, name) module.hessian_helper.setup() all_patched_modules.append((name, module)) if module.hessian_helper.is_enabled: -======= - module.local_hessian = LocalHessianHelper(module, name) - module.local_hessian.setup() - all_patched_modules.append((name, module)) - if module.local_hessian.is_enabled: ->>>>>>> e391ea1a (add rabbit feedback) weight_quantizers_info.append((name, module)) # Cache activations by running forward loop @@ -696,11 +689,7 @@ def quant_func(x, amax, quantizer=weight_quantizer): # Cleanup and free memory LocalHessianHelper.cache_mode = False for name, module in all_patched_modules: -<<<<<<< HEAD module.hessian_helper.cleanup() -======= - module.local_hessian.cleanup() ->>>>>>> e391ea1a (add rabbit feedback) print_rank_0("local_hessian: Calibration complete.") From 971b1684c18e2f1d50d9bc829d7177586f861aa1 Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Wed, 4 Feb 2026 18:20:37 +0000 Subject: [PATCH 03/48] tested perplexity Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- modelopt/torch/quantization/config.py | 50 +++++++++++++++++++++++++++ modelopt/torch/quantization/mode.py | 14 ++++++++ modelopt/torch/utils/network.py | 1 + 3 files changed, 65 insertions(+) diff --git a/modelopt/torch/quantization/config.py 
b/modelopt/torch/quantization/config.py index cf2336bf4a..1b6e81377c 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -269,6 +269,18 @@ "algorithm": "max", } +INT4_BLOCKWISE_WEIGHT_ONLY_GPTQ_CFG = { + "quant_cfg": { + "*weight_quantizer": {"num_bits": 4, "block_sizes": {-1: 128}, "enable": True}, + "*input_quantizer": {"enable": False}, + **_default_disabled_quantizer_cfg, + }, + "algorithm": { + "method": "gptq", + "use_sequential": True, + }, +} + INT4_AWQ_CFG = { "quant_cfg": { @@ -1346,6 +1358,44 @@ class GPTQLiteConfig(QuantizeAlgorithmConfig): ) +class GPTQConfig(QuantizeAlgorithmConfig): + """The config for GPTQ lite. + + GPTQ lite is a variant of GPTQ that does not exactly follow the official GPTQ implementation. + + GPTQ lite does not perform sequential quantization of layers. This means that the updated + activations are not used to process the next layer. + + The default values are taken from the official GPTQ implementation: + https://github.com/IST-DASLab/FP-Quant/blob/d2e3092f968262c4de5fb050e1aef568a280dadd/src/quantization/gptq.py#L35 + + Note: This feature is currently experimental and may not translate to improved accuracy as expected. + + + """ + + method: Literal["gptq"] = ModeloptField("gptq") + percdamp: float | None = ModeloptField( + default=0.01, + gt=0.0, + le=1.0, + title="Percentage damping factor.", + description="The percentage of average Hessian diagonal used for damping.", + ) + block_size: int | None = ModeloptField( + default=128, + title="Block size for GPTQ weight update.", + description="""The block size for GPTQ weight update, which must be a multiple of the + group_size used in the quantization.""", + ) + hessian_state_path: str | None = ModeloptField( + default=None, + title="Path to the Hessian state file.", + description="""The path to the Hessian state file. 
If hessian path exists, we load from + hessian file instead of recomputing them.""", + ) + + QuantizeQuantCfgType = dict[ str | Callable, QuantizerAttributeConfig diff --git a/modelopt/torch/quantization/mode.py b/modelopt/torch/quantization/mode.py index e08efece9a..88e93bb770 100644 --- a/modelopt/torch/quantization/mode.py +++ b/modelopt/torch/quantization/mode.py @@ -37,6 +37,7 @@ AWQFullCalibConfig, AWQLiteCalibConfig, CompressConfig, + GPTQConfig, GPTQLiteConfig, LocalHessianCalibConfig, MaxCalibConfig, @@ -59,6 +60,7 @@ ) from .model_calib import ( awq, + gptq, gptq_lite, local_hessian_calibrate, max_calibrate, @@ -502,3 +504,15 @@ def config_class(self) -> type[QuantizeAlgorithmConfig]: return GPTQLiteConfig _calib_func = gptq_lite + + +@CalibrateModeRegistry.register_mode +class GPTQModeDescriptor(BaseCalibrateModeDescriptor): + """Mode for GPTQ calibration algorithm.""" + + @property + def config_class(self) -> type[QuantizeAlgorithmConfig]: + """Specifies the config class for the mode.""" + return GPTQConfig + + _calib_func = gptq diff --git a/modelopt/torch/utils/network.py b/modelopt/torch/utils/network.py index b54332375b..b07ca570c4 100644 --- a/modelopt/torch/utils/network.py +++ b/modelopt/torch/utils/network.py @@ -46,6 +46,7 @@ def _convert_to_wrapped_module_name(name: str) -> str: "ModelLike", "compare_dict", "create_param_grad_clear_hook", + "get_decoder_layers", "get_model_attributes", "get_module_device", "get_same_padding", From 10c16cabfcf945fb23c6047fb81c3bcd3e61e3bd Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Mon, 9 Feb 2026 16:46:47 +0000 Subject: [PATCH 04/48] tested, revert later Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- examples/llm_ptq/hf_ptq.py | 76 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 76 insertions(+) diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index 5620ddf6a4..83a2516bd0 100755 --- 
a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -678,6 +678,82 @@ def export_quantized( "Unified HF export format does not specify inference tensor parallel or pipeline parallel. " "They will be set at deployment time." ) + if True: + # Disable quantizers + # mtq.fold_weight(full_model) + # print("Folded weights") + print("Disabling quantizers for perplexity evaluation (weights are already QDQ'ed)") + mtq.disable_quantizer(full_model, "*") + if True: + # mtq.fold_weight(full_model) + import os + + import torch.nn.functional as F + from datasets import load_dataset + from tqdm import trange + from transformers import AutoTokenizer + + # Set cache directory to work directory to avoid disk space issues + cache_dir = os.path.join( + os.path.dirname(os.path.abspath(__file__)), ".hf_cache" + ) + os.makedirs(cache_dir, exist_ok=True) + os.environ["HF_DATASETS_CACHE"] = cache_dir + print(f"Using HuggingFace datasets cache: {cache_dir}") + + def _get_wikitext2(tokenizer: AutoTokenizer, sequence_length: int): + test_dataset_raw = load_dataset( + "wikitext", "wikitext-2-raw-v1", split="test", cache_dir=cache_dir + ) + test_dataset_tok = tokenizer( + "\n\n".join(test_dataset_raw["text"]), return_tensors="pt" + ).input_ids + num_test_sequences = test_dataset_tok.numel() // sequence_length + test_loader = [ + test_dataset_tok[:, i * sequence_length : (i + 1) * sequence_length] + for i in range(num_test_sequences) + ] + return test_loader + + @torch.no_grad() + def _compute_perplexity(model, data, batch_size: int = 1): + num_samples = len(data) + device = next(model.parameters()).device + # Running estimate of negative log-likelihood + nll_running = 0 + # Number of tokens processed to far + tokens_processed = 0 + # Loop through each batch + for i in trange( + 0, num_samples, batch_size, desc="Computing perplexity", leave=False + ): + j = min(i + batch_size, num_samples) + inputs = torch.cat(data[i:j]).to(device) + # Forward pass through the model + lm_logits = 
model(inputs).logits + # Shift logits and labels for next token prediction + shift_logits = lm_logits[:, :-1, :].contiguous() + shift_labels = inputs[:, 1:] + # Compute loss + loss = F.cross_entropy( + shift_logits.reshape(-1, shift_logits.size(-1)), + shift_labels.reshape(-1), + ) + # Calculate negative log likelihood + a = shift_labels.numel() / (tokens_processed + shift_labels.numel()) + b = tokens_processed / (tokens_processed + shift_labels.numel()) + nll_running = a * loss + b * nll_running + # Update number of processed tokens + tokens_processed += shift_labels.numel() + # Compute perplexity + ppl = nll_running.exp().item() + return ppl + + eval_data = _get_wikitext2(tokenizer, 2048) + ppl = _compute_perplexity(full_model, eval_data) + print(f"Wikitext-2 perplexity: {round(ppl, 2):.2f}") + + breakpoint() # Load any missing weights from non-standard safetensors (handled in get_model for non-low-memory mode) # Store the MTP layer prefixes on the model for later exclusion from quantization From 364fd783f0c489a8376210a99d257a401617d409 Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Tue, 10 Feb 2026 04:41:46 +0000 Subject: [PATCH 05/48] tested Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- examples/llm_ptq/hf_ptq.py | 17 +++-- modelopt/torch/quantization/config.py | 94 +++++++++++++++++++++++++++ 2 files changed, 106 insertions(+), 5 deletions(-) diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index 83a2516bd0..657acf1340 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -678,14 +678,16 @@ def export_quantized( "Unified HF export format does not specify inference tensor parallel or pipeline parallel. " "They will be set at deployment time." 
) - if True: + if args.export_qdq_weights: # Disable quantizers - # mtq.fold_weight(full_model) - # print("Folded weights") + if "gptq" not in args.qformat: + mtq.fold_weight(full_model) + print("Folded weights") + print("Disabling quantizers for perplexity evaluation (weights are already QDQ'ed)") mtq.disable_quantizer(full_model, "*") + if True: - # mtq.fold_weight(full_model) import os import torch.nn.functional as F @@ -753,7 +755,6 @@ def _compute_perplexity(model, data, batch_size: int = 1): ppl = _compute_perplexity(full_model, eval_data) print(f"Wikitext-2 perplexity: {round(ppl, 2):.2f}") - breakpoint() # Load any missing weights from non-standard safetensors (handled in get_model for non-low-memory mode) # Store the MTP layer prefixes on the model for later exclusion from quantization @@ -1220,6 +1221,12 @@ def parse_args() -> argparse.Namespace: default=False, action="store_true", ) + parser.add_argument( + "--export_qdq_weights", + help=("Used for GPTQ weights as is without compressed weights for deployment."), + default=False, + action="store_true", + ) parser.add_argument( "--verbose", help="Print verbose output (e.g. quantization summary). 
Disable by --no-verbose.", diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index 1b6e81377c..1e66b79b6c 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -281,6 +281,100 @@ }, } +NVFP4_STATIC_WO_GPTQ_CFG = { + "quant_cfg": { + "*weight_quantizer": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + }, + "*input_quantizer": { + "enable": False, + }, + **_default_disabled_quantizer_cfg, + }, + "algorithm": { + "method": "gptq", + "use_sequential": True, + }, +} + +NVFP4_STATIC_WO_GPTQ_LITE_CFG = { + "quant_cfg": { + "*weight_quantizer": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + }, + "*input_quantizer": { + "enable": False, + }, + **_default_disabled_quantizer_cfg, + }, + "algorithm": { + "method": "gptq_lite", + "use_sequential": False, + }, +} + +NVFP4_STATIC_WO_CFG = { + "quant_cfg": { + "*weight_quantizer": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + }, + "*input_quantizer": { + "enable": False, + }, + **_default_disabled_quantizer_cfg, + }, + "algorithm": { + "method": "max", + "use_sequential": False, + }, +} + +NVFP4_STATIC_WO_GPTQ_LITE_CFG = { + "quant_cfg": { + "*weight_quantizer": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + }, + "*input_quantizer": { + "enable": False, + }, + **_default_disabled_quantizer_cfg, + }, + "algorithm": { + "method": "gptq_lite", + "use_sequential": False, + }, +} + +NVFP4_DYNAMIC_WO_CFG = { + "quant_cfg": { + "*weight_quantizer": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + }, + "*input_quantizer": { + "enable": False, + }, + 
**_default_disabled_quantizer_cfg, + }, + "algorithm": { + "method": "gptq_lite", + "use_sequential": False, + }, +} INT4_AWQ_CFG = { "quant_cfg": { From 5aee5172a5ce7d412a887fc5f87ab8edea3b3642 Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Wed, 11 Feb 2026 07:43:06 +0000 Subject: [PATCH 06/48] refactor Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- modelopt/torch/quantization/model_calib.py | 80 +++++----------------- 1 file changed, 16 insertions(+), 64 deletions(-) diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index ed57ea3fc7..7b0a3df5e8 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -1601,56 +1601,6 @@ def prepare_hessian_inverse(h, weight, percdamp): return h_inv -def quantize_block(full_weight, block_start, block_end, h_inv, quantizer): - """Quantize a block of weights group by group (based on quantizer block sizes) with error propagation. 
- - Args: - full_weight: The full weight tensor (needed for INT4 quantization) - block_start: Starting column index of the block - block_end: Ending column index of the block - h_inv: Hessian inverse - quantizer: The quantizer to apply - Returns: - quantized_block: Quantized weights for this block - losses: Quantization losses per element - errors: Accumulated errors for propagation - """ - # Extract the block we're working on - block_weight = full_weight[:, block_start:block_end] - block_hinv = h_inv[block_start:block_end, block_start:block_end] - block_size = block_end - block_start - - quantized_block = torch.zeros_like(block_weight) - losses = torch.zeros_like(block_weight) - errors = torch.zeros_like(block_weight) - - # We perform column-wise update for GPTQ within the block - group_size = 1 - - for group_start in range(0, block_size, group_size): - group_end = min(group_start + group_size, block_size) - group_cols = slice(group_start, group_end) - # Get current column and its Hessian inverse diagonal - weight_col = block_weight[:, group_cols] - hinv_diag = torch.diag(block_hinv[group_cols, group_cols]) - - # Quantize using the full weight, then extract the columns we need - quantized_full = quantizer(full_weight) - quantized_cols = quantized_full[:, block_start + group_start : block_start + group_end] - quantized_block[:, group_cols] = quantized_cols - - # Compute quantization error and loss - error = (weight_col - quantized_cols) / hinv_diag - losses[:, group_cols] = (weight_col - quantized_cols) ** 2 / (hinv_diag**2) / 2 - errors[:, group_cols] = error - - # Propagate error to remaining columns in block - block_weight[:, group_start:] -= error @ block_hinv[group_start:group_end, group_start:] - full_weight[:, block_start:block_end] = block_weight - - return quantized_block, losses, errors - - def blockwise_weight_update(module, h, block_size, percdamp): """Update module weights using GPTQ-style blockwise quantization. 
@@ -1666,28 +1616,30 @@ def blockwise_weight_update(module, h, block_size, percdamp): # Preprocess Hessian: handle dead neurons and add damping h_inv = prepare_hessian_inverse(h, weight, percdamp) - # Initialize output tensors - quantized_weight = torch.zeros_like(weight) - losses = torch.zeros_like(weight) - # Process weights in blocks for block_start in range(0, num_cols, block_size): block_end = min(block_start + block_size, num_cols) - - quantized_block, block_losses, block_errors = quantize_block( - weight, block_start, block_end, h_inv, module.weight_quantizer - ) - # Store results - quantized_weight[:, block_start:block_end] = quantized_block - losses[:, block_start:block_end] = block_losses + n_cols = block_end - block_start + wblk = weight.clone() + errs = torch.zeros_like(wblk[:, block_start:block_end]) + h_inv_cho_blk = h_inv[block_start:block_end, block_start:block_end] + + for i in range(n_cols): + w_ci = wblk[:, block_start + i] + d = h_inv_cho_blk[i, i] + qdq = module.weight_quantizer(wblk) + weight[:, block_start + i] = qdq[:, block_start + i] + err = (w_ci - qdq[:, block_start + i]) / d + wblk[:, block_start + i : block_end].addr_(err, h_inv_cho_blk[i, i:], alpha=-1) + errs[:, i] = err # Propagate errors to remaining weights - weight[:, block_end:] -= block_errors @ h_inv[block_start:block_end, block_end:] + weight[:, block_end:].addmm_(errs, h_inv[block_start:block_end, block_end:], alpha=-1) # Print relative mse error - _print_relative_mse_error(quantized_weight, module.weight.float(), h, module.name) + _print_relative_mse_error(weight, module.weight.float(), h, module.name) # Update module weights - module.weight.data = quantized_weight.reshape(module.weight.shape).to(module.weight.data.dtype) + module.weight.data = weight.reshape(module.weight.shape).to(module.weight.data.dtype) def gptq_lite( From 4b1e42fb9944b4e574a33603cdeecee3370294ed Mon Sep 17 00:00:00 2001 From: realAsma <86726418+realAsma@users.noreply.github.com> Date: Fri, 6 Feb 2026 
11:47:36 -0800 Subject: [PATCH 07/48] Track global_amax for weight FP4 MSE sweep; Refactor to NVFP4StaticQantizer, NVFP4MSECalibrator (#849) **Type of change:** ? **Overview:** ? ```python ``` - **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed. - **Is this change backward compatible?**: Yes/No - **Did you write any new necessary tests?**: Yes/No - **Did you add or update any necessary documentation?**: Yes/No - **Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?**: Yes/No * **New Features** * Added NVFP4StaticQuantizer for improved 4-bit quantization with enhanced precision control * Introduced NVFP4MSECalibrator with flexible candidate generation for calibration optimization * **Improvements** * Optimized GPU kernels for Hopper+ graphics cards with better performance * Extended Triton support to broader GPU compatibility * Enhanced backward compatibility for restoring previously quantized models * **Tests** * Added comprehensive test coverage for new quantizers and calibration methods --------- Signed-off-by: realAsma Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- modelopt/torch/quantization/triton/__init__.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/modelopt/torch/quantization/triton/__init__.py b/modelopt/torch/quantization/triton/__init__.py index def70e5914..6e8d4dba11 100644 --- a/modelopt/torch/quantization/triton/__init__.py +++ b/modelopt/torch/quantization/triton/__init__.py @@ -34,6 +34,10 @@ from .fp4_kernel import * from .fp8_kernel import * + # fp4_kernel_hopper requires compute >= 8.9 (uses tl.float8e4nv) + if torch.cuda.get_device_capability() >= (8, 9): + from .fp4_kernel_hopper import * + # fp4_kernel_hopper requires compute >= 8.9 (uses tl.float8e4nv) if torch.cuda.get_device_capability() >= (8, 9): from .fp4_kernel_hopper import * From 
6a15d0da0bf64365cbffcf0fb1a6d6d3ffb3f9da Mon Sep 17 00:00:00 2001 From: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com> Date: Fri, 6 Feb 2026 22:56:02 +0000 Subject: [PATCH 08/48] address reviewers feedback, delegate scaling factor calculation to NVFP4QTensor Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com> Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- modelopt/torch/export/quant_utils.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/modelopt/torch/export/quant_utils.py b/modelopt/torch/export/quant_utils.py index 4ceb51cd2c..b762757cb9 100755 --- a/modelopt/torch/export/quant_utils.py +++ b/modelopt/torch/export/quant_utils.py @@ -360,9 +360,10 @@ def get_weight_scaling_factor(module: nn.Module, weight_name: str = "weight") -> QUANTIZATION_NVFP4_SVDQUANT, QUANTIZATION_W4A8_NVFP4_FP8, ]: - # Calibrate weight quantizer if amax is not set - module_name = f"{type(module).__name__}.{weight_name}" - _ensure_weight_quantizer_calibrated(weight_quantizer, weight, module_name) + # Calibrate weight quantizer if amax is not set (only needed for dynamic quantizers) + if not is_nvfp4_static: + module_name = f"{type(module).__name__}.{weight_name}" + _ensure_weight_quantizer_calibrated(weight_quantizer, weight, module_name) if quantization_format == QUANTIZATION_W4A8_NVFP4_FP8: # weight_scaling_factor_2 for w4a8 needs to be amax/448, so that the wsf is in range 448/6. 
From 7b7146bfc60e8c23c3e4200e0e93402fe08be635 Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Wed, 4 Feb 2026 18:20:37 +0000 Subject: [PATCH 09/48] tested perplexity Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- modelopt/torch/quantization/model_calib.py | 87 ++++++++++++++++++++++ 1 file changed, 87 insertions(+) diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 7b0a3df5e8..6fff65410c 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -1835,3 +1835,90 @@ def _layer_forward_loop(m, _inputs=layer_inputs): torch.cuda.empty_cache() finally: input_getter._unpatch_all_layers() + + print_rank_0("Sequential calibration completed successfully") + + +@torch.no_grad() +def gptq( + layer: nn.Module, + inputs: list[tuple[tuple, dict]], + percdamp: float = 0.01, + block_size: int = 128, + **kwargs, +): + """GPTQ quantization - a GPTQ variant.""" + import time + + total_start = time.time() + + # Dictionary to store hessian matrices for all linear layers in this decoder + hessian_state = {} + + # Phase 1: Build tensor mapping for all quantized linear layers in this decoder layer + tensor_mapping = {} + for name, module in layer.named_modules(): + if is_quantized_linear(module) and module.weight_quantizer.is_enabled: + in_features = module.weight.shape[-1] + tensor_mapping[name] = ((in_features, in_features), module.weight.device) + module.name = name # Attach name for easy access in hooks + + if not tensor_mapping: + print_rank_0("No quantized linear layers found in decoder layer, skipping GPTQ") + return + + # Initialize hessian state with zeros + for name, (shape, device) in tensor_mapping.items(): + hessian_state[name] = { + "hessian": torch.zeros(shape, dtype=torch.float32, device=device), + "n_samples": 0, + } + + # Phase 2: Register hooks to collect Hessians during forward passes 
+ def hessian_hook(module, input, output): + """Hook to intercept activations and update hessian matrix.""" + state = hessian_state[module.name] + hessian, n_samples = update_hessian(input[0], state["hessian"], state["n_samples"]) + hessian_state[module.name] = {"hessian": hessian, "n_samples": n_samples} + + handles = [] + for name, module in layer.named_modules(): + if is_quantized_linear(module) and module.weight_quantizer.is_enabled: + handles.append(module.register_forward_hook(hessian_hook)) + + # Run forward passes with the provided inputs to collect Hessians + hessian_start = time.time() + print_rank_0( + f"Computing Hessians for {len(tensor_mapping)} linear layers using {len(inputs)} batches..." + ) + for args, kwargs_input in inputs: + layer(*args, **kwargs_input) + + # Remove hooks after collecting Hessians + for handle in handles: + handle.remove() + + torch.cuda.synchronize() if torch.cuda.is_available() else None + hessian_time = time.time() - hessian_start + + # Phase 3: Update weights using computed Hessians (same as gptq_lite) + weight_update_start = time.time() + print_rank_0("Updating weights using GPTQ algorithm...") + for name, module in layer.named_modules(): + if is_quantized_linear(module) and module.weight_quantizer.is_enabled: + state = hessian_state[module.name] + hessian = state["hessian"].to(module.weight.device) + blockwise_weight_update(module, hessian, block_size, percdamp) + # Free memory + del hessian_state[module.name] + torch.cuda.empty_cache() + + torch.cuda.synchronize() if torch.cuda.is_available() else None + weight_update_time = time.time() - weight_update_start + + total_time = time.time() - total_start + print_rank_0( + f"GPTQ timing - Hessian: {hessian_time:.2f}s, " + f"Weight update: {weight_update_time:.2f}s, " + f"Total: {total_time:.2f}s" + ) From 40c14ef0fe599beeca6dd90d116622622b8c2676 Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Thu, 12 Feb 2026 18:54:46 +0000 
Subject: [PATCH 10/48] tested exported checkpoints on 0211 Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- examples/llm_ptq/hf_ptq.py | 69 ++++++++++++++++++++++ modelopt/torch/export/unified_export_hf.py | 4 +- modelopt/torch/quantization/config.py | 22 +++++++ modelopt/torch/quantization/model_calib.py | 57 +++++++++++++++++- 4 files changed, 147 insertions(+), 5 deletions(-) diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index 657acf1340..43c2a2e7e4 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -678,6 +678,75 @@ def export_quantized( "Unified HF export format does not specify inference tensor parallel or pipeline parallel. " "They will be set at deployment time." ) + + if True: + import os + + import torch.nn.functional as F + from datasets import load_dataset + from tqdm import trange + from transformers import AutoTokenizer + + # Set cache directory to work directory to avoid disk space issues + cache_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), ".hf_cache") + os.makedirs(cache_dir, exist_ok=True) + os.environ["HF_DATASETS_CACHE"] = cache_dir + print(f"Using HuggingFace datasets cache: {cache_dir}") + + def _get_wikitext2(tokenizer: AutoTokenizer, sequence_length: int): + test_dataset_raw = load_dataset( + "wikitext", "wikitext-2-raw-v1", split="test", cache_dir=cache_dir + ) + test_dataset_tok = tokenizer( + "\n\n".join(test_dataset_raw["text"]), return_tensors="pt" + ).input_ids + num_test_sequences = test_dataset_tok.numel() // sequence_length + test_loader = [ + test_dataset_tok[:, i * sequence_length : (i + 1) * sequence_length] + for i in range(num_test_sequences) + ] + return test_loader + + @torch.no_grad() + def _compute_perplexity(model, data, batch_size: int = 1): + num_samples = len(data) + device = next(model.parameters()).device + # Running estimate of negative log-likelihood + nll_running = 0 + # Number of tokens processed to far + 
tokens_processed = 0 + # Loop through each batch + for i in trange( + 0, num_samples, batch_size, desc="Computing perplexity", leave=False + ): + j = min(i + batch_size, num_samples) + inputs = torch.cat(data[i:j]).to(device) + # Forward pass through the model + lm_logits = model(inputs).logits + # Shift logits and labels for next token prediction + shift_logits = lm_logits[:, :-1, :].contiguous() + shift_labels = inputs[:, 1:] + # Compute loss + loss = F.cross_entropy( + shift_logits.reshape(-1, shift_logits.size(-1)), + shift_labels.reshape(-1), + ) + # Calculate negative log likelihood + a = shift_labels.numel() / (tokens_processed + shift_labels.numel()) + b = tokens_processed / (tokens_processed + shift_labels.numel()) + nll_running = a * loss + b * nll_running + # Update number of processed tokens + tokens_processed += shift_labels.numel() + # Compute perplexity + ppl = nll_running.exp().item() + return ppl + + eval_data = _get_wikitext2(tokenizer, 2048) + ppl = _compute_perplexity(full_model, eval_data) + print(f"Wikitext-2 perplexity: {round(ppl, 2):.2f}") + print(f"Saving model to {args.export_path}") + full_model.save_pretrained(args.export_path) + if args.export_qdq_weights: # Disable quantizers if "gptq" not in args.qformat: diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index 14a12bcdf3..92c196e151 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -556,7 +556,7 @@ def _export_quantized_weight( )[0] quantized_weight = to_quantized_weight( - weight.to(dtype), + weight.to(torch.bfloat16), weight_scale, quantization_format, weight_scale_2, @@ -573,7 +573,7 @@ def _export_quantized_weight( ) quantized_weight = to_quantized_weight( - weight.to(dtype), + weight.to(torch.bfloat16), weight_scale, quantization_format, weight_scale_2, diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index 1e66b79b6c..272e04642e 
100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -300,6 +300,28 @@ }, } +NVFP4_STATIC_GPTQ_CFG = { + "quant_cfg": { + "*weight_quantizer": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + }, + "*input_quantizer": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + }, + **_default_disabled_quantizer_cfg, + }, + "algorithm": { + "method": "gptq", + "use_sequential": True, + }, +} + NVFP4_STATIC_WO_GPTQ_LITE_CFG = { "quant_cfg": { "*weight_quantizer": { diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 6fff65410c..2ac165df9c 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -15,6 +15,7 @@ """Calibration utilities.""" +import contextlib import math import os import warnings @@ -1789,6 +1790,56 @@ def hessian_hook(module, input, output): print_rank_0("GPTQ-lite quantization completed successfully") +def _set_input_quantizers_calib_mode(layer: nn.Module): + """Set all input quantizers of a layer to calibration mode.""" + for name, module in layer.named_modules(): + if ( + isinstance(module, TensorQuantizer) + and "input_quantizer" in name + and not module._disabled + and not module._dynamic + and module._calibrator is not None + ): + module._calibrator.reset() + module.disable_quant() + module.enable_calib() + + +def _set_input_quantizers_quant_mode(layer: nn.Module): + """Load fresh amaxes and restore all input quantizers of a layer to quant mode.""" + for name, module in layer.named_modules(): + if ( + isinstance(module, TensorQuantizer) + and "input_quantizer" in name + and not module._disabled + and not module._dynamic + and module._calibrator is not None + ): + if module._calibrator.compute_amax() is not None: + module.load_calib_amax() + 
module.enable_quant() + module.disable_calib() + + +@contextlib.contextmanager +def _disable_input_quantizers(layer: nn.Module): + """Temporarily disable all enabled input quantizers in a layer.""" + enabled_quantizers = [] + for name, module in layer.named_modules(): + if ( + isinstance(module, TensorQuantizer) + and "input_quantizer" in name + and not module._disabled + ): + module.disable() + enabled_quantizers.append(module) + try: + yield + finally: + for module in enabled_quantizers: + module.enable() + + @torch.no_grad() def sequential_calibrate( model: nn.Module, @@ -1836,8 +1887,6 @@ def _layer_forward_loop(m, _inputs=layer_inputs): finally: input_getter._unpatch_all_layers() - print_rank_0("Sequential calibration completed successfully") - @torch.no_grad() def gptq( @@ -1877,8 +1926,10 @@ def gptq( # Phase 2: Register hooks to collect Hessians during forward passes def hessian_hook(module, input, output): """Hook to intercept activations and update hessian matrix.""" + if hasattr(module, "input_quantizer") and module.input_quantizer.is_enabled: + inp = module.input_quantizer(input[0]) state = hessian_state[module.name] - hessian, n_samples = update_hessian(input[0], state["hessian"], state["n_samples"]) + hessian, n_samples = update_hessian(inp, state["hessian"], state["n_samples"]) hessian_state[module.name] = {"hessian": hessian, "n_samples": n_samples} handles = [] From 7a1e0060ed8216da584928edbb802b7cd1bdca69 Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Fri, 13 Feb 2026 19:53:25 +0000 Subject: [PATCH 11/48] tested nano v3 Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- examples/llm_ptq/hf_ptq.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index 43c2a2e7e4..eaf036cf21 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -744,8 +744,7 @@ def 
_compute_perplexity(model, data, batch_size: int = 1): eval_data = _get_wikitext2(tokenizer, 2048) ppl = _compute_perplexity(full_model, eval_data) print(f"Wikitext-2 perplexity: {round(ppl, 2):.2f}") - print(f"Saving model to {args.export_path}") - full_model.save_pretrained(args.export_path) + breakpoint() if args.export_qdq_weights: # Disable quantizers @@ -753,8 +752,8 @@ def _compute_perplexity(model, data, batch_size: int = 1): mtq.fold_weight(full_model) print("Folded weights") - print("Disabling quantizers for perplexity evaluation (weights are already QDQ'ed)") - mtq.disable_quantizer(full_model, "*") + print(f"Saving model to {args.export_path}") + full_model.save_pretrained(args.export_path) if True: import os From e6df37964be1ffe6734c2bc0ccabb19aaee6be76 Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Mon, 16 Feb 2026 02:48:11 +0000 Subject: [PATCH 12/48] added activation MSE logging Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- examples/llm_ptq/hf_ptq.py | 48 ++++++++++++++++++++++ modelopt/torch/quantization/__init__.py | 1 + modelopt/torch/quantization/model_calib.py | 2 + 3 files changed, 51 insertions(+) diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index eaf036cf21..9c1c58e5c9 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -15,6 +15,7 @@ import argparse import copy +import os import random import time import warnings @@ -568,6 +569,43 @@ def mono_quantize( else: calibrate_loop = create_forward_loop(dataloader=calib_dataloader) + # Phase 1: Collect pre-quantization activations (batch_size=1 to save memory) + if getattr(args, "measure_activation_mse", False): + mse_max_samples = getattr(args, "activation_mse_max_samples", 16) + mse_save_dir = getattr(args, "activation_mse_save_dir", None) + mse_input_path = getattr(args, "activation_mse_input_path", None) + + # Materialize or load a frozen set of MSE inputs so that 
the exact + # same samples are used across runs and across codebases. + if mse_input_path and os.path.isfile(mse_input_path): + mse_data = mtq.ActivationMSELogger.load_data(mse_input_path) + else: + from torch.utils.data import DataLoader as _DataLoader + + mse_dataloader = _DataLoader(calib_dataloader.dataset, batch_size=1, shuffle=False) + if mse_input_path: + mse_data = mtq.ActivationMSELogger.materialize_data( + mse_dataloader, + mse_input_path, + max_samples=mse_max_samples, + ) + else: + # No path given -- materialize in memory only + mse_data = [] + for i, batch in enumerate(mse_dataloader): + if i >= mse_max_samples: + break + t = batch["input_ids"] if isinstance(batch, dict) else batch + mse_data.append(t.cpu()) + + mse_logger = mtq.ActivationMSELogger( + max_samples=mse_max_samples, + layer_filter=getattr(args, "activation_mse_layer_filter", None), + save_dir=mse_save_dir, + ) + print("\n--- Phase 1: Collecting pre-quantization activations ---") + mse_logger.collect(language_model, mse_data, phase="original") + if calibration_only: language_model = mtq.calibrate( language_model, quant_cfg["algorithm"], forward_loop=calibrate_loop @@ -575,6 +613,16 @@ def mono_quantize( else: language_model = mtq.quantize(language_model, quant_cfg, forward_loop=calibrate_loop) + # Phase 2: Compute MSE against stored pre-quant activations + if getattr(args, "measure_activation_mse", False): + print("\n--- Phase 2: Computing per-layer activation MSE ---") + mse_logger.collect(language_model, mse_data, phase="quantized") + mse_logger.compute_mse() + print(mse_logger.summary()) + if mse_save_dir: + mse_logger.save() + del mse_logger, mse_data + # For VL models, update full_model to use the quantized language model if is_nemotron_vl_model: language_model_lineage = get_language_model_from_vl(full_model) diff --git a/modelopt/torch/quantization/__init__.py b/modelopt/torch/quantization/__init__.py index 87dbf30bb5..757b844fb1 100644 --- a/modelopt/torch/quantization/__init__.py 
+++ b/modelopt/torch/quantization/__init__.py @@ -19,6 +19,7 @@ from . import mode, plugins, utils # Add methods to mtq namespace +from .activation_mse import ActivationMSELogger, collect_activations, measure_activation_mse from .compress import * from .config import * from .conversion import * diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 2ac165df9c..959e6117a5 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -1928,6 +1928,8 @@ def hessian_hook(module, input, output): """Hook to intercept activations and update hessian matrix.""" if hasattr(module, "input_quantizer") and module.input_quantizer.is_enabled: inp = module.input_quantizer(input[0]) + else: + inp = input[0] state = hessian_state[module.name] hessian, n_samples = update_hessian(inp, state["hessian"], state["n_samples"]) hessian_state[module.name] = {"hessian": hessian, "n_samples": n_samples} From b81fed8c9a8b6cf354d16fcf7d8aef7283efe970 Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Tue, 17 Feb 2026 06:07:59 +0000 Subject: [PATCH 13/48] super v3 run Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- modelopt/torch/quantization/config.py | 72 +++++++++ modelopt/torch/quantization/model_calib.py | 137 +++++++++++++++++- .../nn/modules/tensor_quantizer.py | 13 +- 3 files changed, 212 insertions(+), 10 deletions(-) diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index 272e04642e..e832c999a5 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -171,6 +171,54 @@ "*o_proj*": {"enable": False}, # Skip QKV Output Projection } +SUPER_NVFP4_CONSERVATIVE_CFG = { + "quant_cfg": { + "*weight_quantizer": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + }, + 
"*input_quantizer": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + }, + **_default_disabled_quantizer_cfg, + **_mamba_moe_disabled_quantizer_cfg, + "*mixer.in_proj*": {"enable": False}, # Skip mamba linear + "*mixer.out_proj*": {"enable": False}, # Skip mamba linear + }, + "algorithm": "max", +} + +SUPER_NVFP4_CONSERVATIVE_GPTQ_CFG = { + "quant_cfg": { + "*weight_quantizer": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + }, + "*input_quantizer": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + }, + **_default_disabled_quantizer_cfg, + **_mamba_moe_disabled_quantizer_cfg, + "*mixer.in_proj*": {"enable": False}, # Skip mamba linear + "*mixer.out_proj*": {"enable": False}, # Skip mamba linear + }, + "algorithm": { + "method": "gptq", + "use_sequential": True, + }, +} + + INT8_DEFAULT_CFG = { "quant_cfg": { "*weight_quantizer": {"num_bits": 8, "axis": 0}, @@ -293,6 +341,9 @@ "enable": False, }, **_default_disabled_quantizer_cfg, + **_mamba_moe_disabled_quantizer_cfg, + "*mixer.in_proj*": {"enable": False}, # Skip mamba linear + "*mixer.out_proj*": {"enable": False}, # Skip mamba linear }, "algorithm": { "method": "gptq", @@ -315,6 +366,9 @@ "enable": True, }, **_default_disabled_quantizer_cfg, + **_mamba_moe_disabled_quantizer_cfg, + "*mixer.in_proj*": {"enable": False}, # Skip mamba linear + "*mixer.out_proj*": {"enable": False}, # Skip mamba linear }, "algorithm": { "method": "gptq", @@ -581,6 +635,9 @@ def _nvfp4_selective_quant_cfg( }, "*input_quantizer": _nvfp4_quantizer, **_default_disabled_quantizer_cfg, + **_mamba_moe_disabled_quantizer_cfg, + "*mixer.in_proj*": {"enable": False}, # Skip mamba linear + "*mixer.out_proj*": {"enable": False}, # Skip mamba linear }, "algorithm": { "method": "mse", @@ -1171,6 +1228,21 @@ class 
QuantizeAlgorithmConfig(ModeloptBaseConfig): ), ) + checkpoint_every_n_layers: int | None = ModeloptField( + default=None, + title="Save intermediate checkpoint every N layers during sequential calibration.", + ) + + checkpoint_dir: str | None = ModeloptField( + default=None, + title="Directory for saving/loading intermediate GPTQ checkpoints.", + ) + + resume_from_layer: int = ModeloptField( + default=0, + title="Layer index to resume sequential calibration from (0 = start from beginning).", + ) + class MaxCalibConfig(QuantizeAlgorithmConfig): """The config for max calibration algorithm. diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 959e6117a5..103a6dc453 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -16,6 +16,8 @@ """Calibration utilities.""" import contextlib +import datetime +import json import math import os import warnings @@ -1520,7 +1522,13 @@ def postprocess(module, name): max_calibrate(model, forward_loop) -def _print_relative_mse_error(q: torch.Tensor, w: torch.Tensor, h: torch.Tensor, module_name: str): +def _print_relative_mse_error( + q: torch.Tensor, + w: torch.Tensor, + h: torch.Tensor, + module_name: str, + n_samples: int | None = None, +): """Print relative mean squared error between quantized and original weights. 
Computes the Hessian-weighted relative MSE between quantized and original weights, @@ -1532,13 +1540,15 @@ def _print_relative_mse_error(q: torch.Tensor, w: torch.Tensor, h: torch.Tensor, w (torch.Tensor): Original weight tensor h (torch.Tensor): Hessian matrix used for weighting the error module_name (str): Name of the module for logging purposes + n_samples (int | None): Number of Hessian samples (batches) used for this layer Note: Implementation adapted from the GPTQ repository: https://github.com/IST-DASLab/FP-Quant """ delta = q - w mse = (delta).mm(h).mul(delta).mean() / (w.mm(h).mul(w).mean() + 1e-6) - print(f"[{module_name}] Relative MSE error: {mse.item():.2e}") + suffix = f", n_hessian_samples: {n_samples}" if n_samples is not None else "" + print(f"[{module_name}] Relative MSE error: {mse.item():.2e}{suffix}") def update_hessian(input, hessian, n_samples): @@ -1551,15 +1561,15 @@ def update_hessian(input, hessian, n_samples): Returns: Tuple of (updated_hessian, new_sample_count) """ - batch_size = input.shape[0] + # Flatten to 2D (total_tokens, features) first, so batch_size counts tokens + input_flat = input.reshape(-1, input.shape[-1]).t().float() + batch_size = input_flat.shape[1] # Incremental averaging: scale down old hessian hessian *= n_samples / (n_samples + batch_size) n_samples += batch_size # Compute outer product: H += (2/n_samples) * X @ X^T - # where X is the flattened input reshaped to (features, batch*seq) - input_flat = input.reshape(-1, input.shape[-1]).t().float() scaled_input = math.sqrt(2 / n_samples) * input_flat hessian.add_((scaled_input @ scaled_input.t()).to(hessian.device)) @@ -1602,7 +1612,7 @@ def prepare_hessian_inverse(h, weight, percdamp): return h_inv -def blockwise_weight_update(module, h, block_size, percdamp): +def blockwise_weight_update(module, h, block_size, percdamp, n_samples=None): """Update module weights using GPTQ-style blockwise quantization. 
Args: @@ -1610,6 +1620,7 @@ def blockwise_weight_update(module, h, block_size, percdamp): H: Hessian matrix (d x d) block_size: Size of blocks to process at once percdamp: Damping percentage for Hessian diagonal + n_samples: Number of Hessian samples for logging (optional) """ weight = module.weight.data.float().clone() _, num_cols = weight.shape @@ -1638,7 +1649,7 @@ def blockwise_weight_update(module, h, block_size, percdamp): weight[:, block_end:].addmm_(errs, h_inv[block_start:block_end, block_end:], alpha=-1) # Print relative mse error - _print_relative_mse_error(weight, module.weight.float(), h, module.name) + _print_relative_mse_error(weight, module.weight.float(), h, module.name, n_samples) # Update module weights module.weight.data = weight.reshape(module.weight.shape).to(module.weight.data.dtype) @@ -1840,11 +1851,117 @@ def _disable_input_quantizers(layer: nn.Module): module.enable() +def save_fake_checkpoint(model: nn.Module, output_dir: str) -> None: + """Save fake quant checkpoint using save_pretrained() (HuggingFace format). + + Args: + model: The quantized model to save. + output_dir: Directory to write the checkpoint into. + """ + from modelopt.torch.opt.conversion import ModeloptStateManager, modelopt_state + from modelopt.torch.quantization.conversion import quantizer_state as get_quantizer_state + + os.makedirs(output_dir, exist_ok=True) + + # Remove accelerate hooks before saving to avoid pickling errors in modelopt_state. + # Accelerate hooks contain local functions (closures like 'add_hook_to_module..new_forward') + # that can't be pickled. Even after removing hooks from modules, they may still be captured + # in closures within quantizer_state metadata when modelopt_state() calls update_last_state_before_save(). 
+ try: + from accelerate.hooks import remove_hook_from_module + + remove_hook_from_module(model, recurse=True) + except ImportError: + pass + + # Save model weights first (without modelopt_state to avoid pickling error) + model.save_pretrained(output_dir, save_modelopt_state=False) + + # Manually save modelopt_state after removing hooks and rebuilding quantizer_state. + # We need to rebuild quantizer_state because hooks may have been captured in closures + # when quantizer_state() was called during update_last_state_before_save() inside modelopt_state(). + if ModeloptStateManager.is_converted(model): + modelopt_state_path = os.path.join(output_dir, "modelopt_state.pth") + state = modelopt_state(model) + + # Rebuild quantizer_state in metadata to remove any hook references captured in closures + if "modelopt_state_dict" in state and isinstance(state["modelopt_state_dict"], list): + cleaned_state_dict = [] + for entry in state["modelopt_state_dict"]: + if isinstance(entry, tuple) and len(entry) >= 2: + mode_str, state_dict_entry = entry[0], entry[1] + if isinstance(state_dict_entry, dict) and "metadata" in state_dict_entry: + # Rebuild quantizer_state after hooks are removed + cleaned_entry = state_dict_entry.copy() + cleaned_metadata = cleaned_entry["metadata"].copy() + cleaned_metadata["quantizer_state"] = get_quantizer_state(model) + cleaned_entry["metadata"] = cleaned_metadata + cleaned_state_dict.append((mode_str, cleaned_entry)) + else: + cleaned_state_dict.append(entry) + else: + cleaned_state_dict.append(entry) + state["modelopt_state_dict"] = cleaned_state_dict + + torch.save(state, modelopt_state_path) + print_rank_0(f"Saved ModelOpt state to {modelopt_state_path}") + + +def _save_gptq_checkpoint( + model: nn.Module, checkpoint_dir: str, last_layer_idx: int, total_layers: int +) -> None: + """Save intermediate GPTQ checkpoint with metadata for resume support. 
+ + Saves accelerate hooks before calling save_fake_checkpoint (which removes them), + then re-attaches them so the model remains functional for subsequent layers. + """ + print_rank_0( + f"Saving GPTQ checkpoint after layer {last_layer_idx}/{total_layers - 1} to {checkpoint_dir}" + ) + + # Save accelerate hooks before save_fake_checkpoint removes them. + # We need to re-attach them after saving so the model keeps working. + saved_hooks = {} + for name, module in model.named_modules(): + if hasattr(module, "_hf_hook"): + saved_hooks[name] = module._hf_hook + + try: + save_fake_checkpoint(model, checkpoint_dir) + finally: + # Re-attach accelerate hooks so the model keeps working for remaining layers. + if saved_hooks: + try: + from accelerate.hooks import add_hook_to_module + + name_to_module = dict(model.named_modules()) + for name, hook in saved_hooks.items(): + if name in name_to_module: + add_hook_to_module(name_to_module[name], hook) + print_rank_0(f"Re-attached {len(saved_hooks)} accelerate hooks") + except ImportError: + pass + + # Save checkpoint metadata for resume support. + meta = { + "last_completed_layer": last_layer_idx, + "total_layers": total_layers, + "timestamp": datetime.datetime.now().isoformat(), + } + meta_path = os.path.join(checkpoint_dir, "gptq_checkpoint_meta.json") + with open(meta_path, "w") as f: + json.dump(meta, f, indent=2) + print_rank_0(f"GPTQ checkpoint saved (layer {last_layer_idx}/{total_layers - 1})") + + @torch.no_grad() def sequential_calibrate( model: nn.Module, forward_loop: ForwardLoop, calib_func: Callable, + checkpoint_every_n_layers: int | None = None, + checkpoint_dir: str | None = None, + resume_from_layer: int = 0, **calib_kwargs, ): """Sequential calibration - a sequential layer-by-layer calibration algorithm. 
@@ -1886,6 +2003,8 @@ def _layer_forward_loop(m, _inputs=layer_inputs): torch.cuda.empty_cache() finally: input_getter._unpatch_all_layers() + + print_rank_0("Sequential calibration completed") @torch.no_grad() @@ -1961,7 +2080,9 @@ def hessian_hook(module, input, output): if is_quantized_linear(module) and module.weight_quantizer.is_enabled: state = hessian_state[module.name] hessian = state["hessian"].to(module.weight.device) - blockwise_weight_update(module, hessian, block_size, percdamp) + blockwise_weight_update( + module, hessian, block_size, percdamp, n_samples=state["n_samples"] + ) # Free memory del hessian_state[module.name] torch.cuda.empty_cache() diff --git a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py index ec2c3cfc55..4317c58609 100644 --- a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py +++ b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py @@ -1331,10 +1331,19 @@ def global_amax(self, value): def _fake_quantize(self, inputs): """Fake quantization using two-level scaling with _amax and _global_amax.""" if self.amax is not None: + # Ensure amax/global_amax are on the same device as inputs. + # After from_pretrained with device_map, quantizer buffers may remain + # on CPU while model weights/activations are on GPU. 
+ amax = self.amax + if amax.device != inputs.device: + amax = amax.to(inputs.device) + global_amax = self.global_amax + if global_amax is not None and global_amax.device != inputs.device: + global_amax = global_amax.to(inputs.device) return static_blockwise_fp4_fake_quant( inputs, - self.amax, - self.global_amax, # Can be None, will be computed internally + amax, + global_amax, # Can be None, will be computed internally True, # quantize_block_scales inputs.dtype, self._pass_through_bwd, From f3a9524d8e2e983d5caa5d2a9dd752d0cca50c6a Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Tue, 17 Feb 2026 21:17:38 +0000 Subject: [PATCH 14/48] added activationmse logging helper Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- modelopt/torch/quantization/activation_mse.py | 787 ++++++++++++++++++ 1 file changed, 787 insertions(+) create mode 100644 modelopt/torch/quantization/activation_mse.py diff --git a/modelopt/torch/quantization/activation_mse.py b/modelopt/torch/quantization/activation_mse.py new file mode 100644 index 0000000000..df90c84a3a --- /dev/null +++ b/modelopt/torch/quantization/activation_mse.py @@ -0,0 +1,787 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Per-layer activation MSE measurement for quantization analysis. 
+ +This module provides utilities to measure per-linear-layer MSE between a model's +activations before and after quantization. Inspired by FP-Quant's two-phase approach: + +- **Phase 1** (before quantization): ``collect_activations()`` runs the model on + calibration data and stores per-layer outputs in CPU RAM. +- **Phase 2** (after quantization): ``measure_activation_mse()`` runs the quantized + model on the same data and computes MSE on-the-fly against the stored Phase 1 + outputs. Only running scalar accumulators are kept -- no second set of tensors + is stored. + +Typical usage in hf_ptq.py:: + + # Phase 1: before quantization + orig_acts = mtq.collect_activations(model, mse_dataloader, max_samples=16) + + # Quantize + model = mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop) + + # Phase 2: after quantization -- computes MSE incrementally + mse = mtq.measure_activation_mse(model, mse_dataloader, orig_acts, max_samples=16) +""" + +import contextlib +import fnmatch +import hashlib +import os +from collections.abc import Iterable +from datetime import datetime + +import torch +import torch.nn as nn +import torch.nn.functional as F +from tqdm import tqdm + +from modelopt.torch.utils.network import get_decoder_layers + +__all__ = ["ActivationMSELogger", "collect_activations", "measure_activation_mse"] + + +def _tensor_from_output(out) -> torch.Tensor: + """Extract a single tensor from a layer's output (handles tuple returns).""" + if isinstance(out, torch.Tensor): + return out.detach() + return out[0].detach() + + +def _is_linear(module: nn.Module) -> bool: + """Check if a module is a linear layer (covers both nn.Linear and quantized linear).""" + return isinstance(module, nn.Linear) + + +def _matches_filter(name: str, layer_filter: str | None) -> bool: + """Check if a layer name matches the optional filter pattern (fnmatch-style).""" + if layer_filter is None: + return True + return fnmatch.fnmatch(name, layer_filter) + + +def _discover_target_layers( + 
model: nn.Module, + layer_filter: str | None = None, +) -> dict[str, nn.Module]: + """Discover linear layers within decoder blocks of the model. + + Uses get_decoder_layers() to find transformer blocks, then finds all linear + submodules within those blocks. Falls back to all linear layers in the model + if decoder blocks cannot be identified. + + Args: + model: The model to inspect. + layer_filter: Optional fnmatch pattern to select specific layers + (e.g., ``"*self_attn*"``). + + Returns: + Dict mapping full module path -> module reference. + """ + decoder_layers = get_decoder_layers(model) + + targets: dict[str, nn.Module] = {} + + if decoder_layers is not None: + # Build a reverse lookup: module id -> full name in model + module_to_name: dict[int, str] = {id(m): n for n, m in model.named_modules()} + + for block in decoder_layers: + block_name = module_to_name.get(id(block), "") + for sub_name, sub_mod in block.named_modules(): + if _is_linear(sub_mod): + full_name = f"{block_name}.{sub_name}" if block_name else sub_name + if _matches_filter(full_name, layer_filter): + targets[full_name] = sub_mod + else: + # Fallback: scan all modules + for name, module in model.named_modules(): + if _is_linear(module): + if _matches_filter(name, layer_filter): + targets[name] = module + + return targets + + +def _run_batch(model: nn.Module, batch) -> None: + """Run a single batch through the model.""" + if isinstance(batch, dict): + model(**batch) + elif isinstance(batch, (list, tuple)): + model(*batch) + else: + model(batch) + + +@torch.no_grad() +def collect_activations( + model: nn.Module, + dataloader: Iterable, + max_samples: int | None = None, + layer_filter: str | None = None, +) -> dict[str, list[torch.Tensor]]: + """Collect per-linear-layer output activations into CPU memory (Phase 1). + + Registers forward hooks on linear layers within the model's decoder blocks, + runs calibration data through the model, and returns captured per-layer outputs. 
+ + Args: + model: The model to collect activations from (typically pre-quantization). + dataloader: An iterable yielding batches (dicts with ``input_ids``, etc.). + Use batch_size=1 to minimize memory. + max_samples: Maximum number of batches to process. ``None`` means all. + layer_filter: Optional fnmatch pattern to restrict which layers are + collected (e.g., ``"*self_attn*"``). ``None`` means all linear layers + inside decoder blocks. + + Returns: + Dict mapping layer name to a list of output tensors (one per batch, on CPU). + """ + was_training = model.training + model.eval() + + # Discover target linear layers + targets = _discover_target_layers(model, layer_filter) + if not targets: + raise ValueError( + f"No linear layers found matching the given filter. layer_filter={layer_filter!r}" + ) + + print(f"Collecting activations for {len(targets)} layers...") + + # Storage: {layer_name: [tensor_per_batch, ...]} + saved: dict[str, list[torch.Tensor]] = {name: [] for name in targets} + captured: dict[str, torch.Tensor] = {} + + def _make_hook(key: str): + def hook(_module: nn.Module, _input, output) -> None: + captured[key] = _tensor_from_output(output).cpu() + + return hook + + # Register hooks + hooks = [] + for name, module in targets.items(): + hooks.append(module.register_forward_hook(_make_hook(name))) + + try: + n_batches = 0 + for batch in tqdm(dataloader, desc="Collecting activations", leave=False): + if max_samples is not None and n_batches >= max_samples: + break + + captured.clear() + _run_batch(model, batch) + + for name in targets: + if name in captured: + saved[name].append(captured[name]) + + n_batches += 1 + finally: + for h in hooks: + h.remove() + + model.train(was_training) + + print(f"Collected {n_batches} samples across {len(targets)} layers") + return saved + + +@torch.no_grad() +def measure_activation_mse( + model: nn.Module, + dataloader: Iterable, + orig_activations: dict[str, list[torch.Tensor]], + max_samples: int | None = None, + 
layer_filter: str | None = None, +) -> dict[str, float]: + """Compute per-layer MSE between stored and live activations (Phase 2). + + Runs the (quantized) model on calibration data and computes MSE on-the-fly + against the pre-quantization activations stored by :func:`collect_activations`. + + Only scalar accumulators (sum of squared errors and element count) are kept + per layer -- no second set of activation tensors is stored. + + The MSE for each layer is computed as:: + + MSE = sum_over_all_elements((orig - quant) ^ 2) / total_elements + + Args: + model: The quantized model to measure. + dataloader: Same dataloader used for :func:`collect_activations` + (must yield batches in the same order). + orig_activations: Output of :func:`collect_activations` -- dict mapping + layer name to a list of pre-quantization output tensors. + max_samples: Maximum number of batches to process (should match Phase 1). + layer_filter: Optional fnmatch pattern (should match Phase 1). + + Returns: + Dict mapping layer name to its MSE value. + """ + was_training = model.training + model.eval() + + # Discover target layers on the (now-quantized) model + targets = _discover_target_layers(model, layer_filter) + + # Only measure layers that exist in both the model and orig_activations + common_keys = sorted(set(targets.keys()) & set(orig_activations.keys())) + if not common_keys: + raise ValueError( + "No matching layers between the quantized model and stored activations. " + "Ensure the same layer_filter is used for both phases." 
+ ) + + skipped = set(orig_activations.keys()) - set(targets.keys()) + if skipped: + print(f"Warning: {len(skipped)} layers in orig_activations not found in model (skipped)") + + print(f"Computing activation MSE for {len(common_keys)} layers...") + + # Scalar accumulators + sum_sq: dict[str, float] = dict.fromkeys(common_keys, 0.0) + count: dict[str, int] = dict.fromkeys(common_keys, 0) + + captured: dict[str, torch.Tensor] = {} + + def _make_hook(key: str): + def hook(_module: nn.Module, _input, output) -> None: + captured[key] = _tensor_from_output(output).cpu() + + return hook + + # Register hooks only on common layers + hooks = [targets[name].register_forward_hook(_make_hook(name)) for name in common_keys] + + try: + batch_idx = 0 + for batch in tqdm(dataloader, desc="Computing activation MSE", leave=False): + if max_samples is not None and batch_idx >= max_samples: + break + + captured.clear() + _run_batch(model, batch) + + for name in common_keys: + if name not in captured: + continue + if batch_idx >= len(orig_activations.get(name, [])): + continue + + o = orig_activations[name][batch_idx].float() + q = captured[name].float() + + if o.shape != q.shape: + print( + f"Warning: shape mismatch for {name} batch {batch_idx}: " + f"{o.shape} vs {q.shape}, skipping" + ) + continue + + sum_sq[name] += F.mse_loss(o, q, reduction="sum").item() + count[name] += o.numel() + + batch_idx += 1 + finally: + for h in hooks: + h.remove() + + model.train(was_training) + + mse = { + key: (sum_sq[key] / count[key]) if count[key] > 0 else float("nan") for key in common_keys + } + + return mse + + +# --------------------------------------------------------------------------- +# Portable ActivationMSELogger class +# --------------------------------------------------------------------------- + + +def _portable_discover_target_layers( + model: nn.Module, + layer_filter: str | None = None, +) -> dict[str, nn.Module]: + """Discover linear layers in decoder blocks with a portable fallback 
chain. + + Strategy: + 1. Try modelopt's ``get_decoder_layers`` (available inside ModelOpt). + 2. Try common HuggingFace attribute paths (``model.model.layers``, etc.). + 3. Fall back to scanning **all** ``nn.Linear`` in ``model.named_modules()``. + + Within each set of decoder blocks the function collects every ``nn.Linear`` + sub-module and optionally filters by *layer_filter* (fnmatch pattern). + """ + decoder_layers = None + + # 1. Try modelopt helper (may not exist when file is copied elsewhere) + with contextlib.suppress(Exception): + decoder_layers = get_decoder_layers(model) + + # 2. Try common HF / other patterns + if decoder_layers is None: + for attr_chain in ( + ("model", "layers"), + ("decoder", "layers"), + ("transformer", "h"), + ("backbone", "layers"), + ): + obj = model + try: + for attr in attr_chain: + obj = getattr(obj, attr) + if isinstance(obj, nn.ModuleList): + decoder_layers = obj + break + except AttributeError: + continue + + targets: dict[str, nn.Module] = {} + + if decoder_layers is not None: + module_to_name: dict[int, str] = {id(m): n for n, m in model.named_modules()} + for block in decoder_layers: + block_name = module_to_name.get(id(block), "") + for sub_name, sub_mod in block.named_modules(): + if isinstance(sub_mod, nn.Linear): + full_name = f"{block_name}.{sub_name}" if block_name else sub_name + if _matches_filter(full_name, layer_filter): + targets[full_name] = sub_mod + else: + # 3. Fallback: all linear layers + for name, module in model.named_modules(): + if isinstance(module, nn.Linear): + if _matches_filter(name, layer_filter): + targets[name] = module + + return targets + + +class ActivationMSELogger: + """Portable activation MSE logger for comparing original vs quantized models. + + Works with both: + + - ``List[Tensor]`` data (**FP-Quant** style): each element is ``[1, seq_len]`` + or ``[B, seq_len]``, consumed via ``model(tensor)``. 
+ - ``DataLoader`` / ``Iterable`` yielding dicts (**ModelOpt** style): + ``{"input_ids": tensor, ...}``, consumed via ``model(**batch)``. + + Guarantees same samples are used for both phases via SHA-256 hashing of + input tensors. Supports saving / loading all activations to disk for + later cross-codebase comparison. + + Example (ModelOpt -- DataLoader with dict batches):: + + mse_logger = ActivationMSELogger(max_samples=16, save_dir="./mse_logs") + mse_logger.collect(model, dataloader, phase="original") + model = mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop) + mse_logger.collect(model, dataloader, phase="quantized") + results = mse_logger.compute_mse() + print(mse_logger.summary()) + mse_logger.save() + + Example (FP-Quant -- List[Tensor]):: + + mse_logger = ActivationMSELogger(max_samples=16, save_dir="./mse_logs") + mse_logger.collect(model_orig, calibration_data, phase="original") + mse_logger.collect(model_quant, calibration_data, phase="quantized") + results = mse_logger.compute_mse() + print(mse_logger.summary()) + mse_logger.save() + """ + + def __init__( + self, + max_samples: int = 16, + layer_filter: str | None = None, + save_dir: str | None = None, + ): + """Initialize the ActivationMSELogger. + + Args: + max_samples: Maximum number of calibration batches to process per phase. + layer_filter: Optional glob pattern to restrict which layers are tracked. + save_dir: Optional directory path for persisting activation data to disk. 
+ """ + self.max_samples = max_samples + self.layer_filter = layer_filter + self.save_dir = save_dir + + # Per-phase state + self.original_activations: dict[str, list[torch.Tensor]] = {} + self.quantized_activations: dict[str, list[torch.Tensor]] = {} + self.input_hashes: list[str] = [] # hashes for "original" phase + self.quant_input_hashes: list[str] = [] # hashes for "quantized" phase + + # Computed after both phases + self.mse_results: dict[str, float] | None = None + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + + @torch.no_grad() + def collect( + self, + model: nn.Module, + data: Iterable, + phase: str, + target_modules: dict[str, nn.Module] | None = None, + ) -> None: + """Collect per-linear-layer output activations for a given phase. + + Args: + model: The model to run (original or quantized). + data: An iterable of batches. Each batch can be: + + - ``torch.Tensor`` with shape ``[B, seq_len]`` (FP-Quant style). + - ``dict`` with at least an ``"input_ids"`` key (ModelOpt style). + - ``list`` / ``tuple`` of tensors. + phase: ``"original"`` or ``"quantized"``. + target_modules: Optional explicit mapping of ``{name: nn.Module}`` + to attach hooks to. If *None*, layers are auto-discovered + via decoder-block scanning. + """ + if phase not in ("original", "quantized"): + raise ValueError(f"phase must be 'original' or 'quantized', got {phase!r}") + + was_training = model.training + model.eval() + + # ----- layer discovery ----- + targets = ( + target_modules + if target_modules is not None + else (_portable_discover_target_layers(model, self.layer_filter)) + ) + if not targets: + raise ValueError( + "No linear layers found. Provide target_modules explicitly or " + f"check layer_filter={self.layer_filter!r}." 
+ ) + + print( + f"[ActivationMSELogger] Phase '{phase}': hooking {len(targets)} layers, " + f"max_samples={self.max_samples}" + ) + + # ----- storage ----- + saved: dict[str, list[torch.Tensor]] = {name: [] for name in targets} + captured: dict[str, torch.Tensor] = {} + hashes: list[str] = [] + + def _make_hook(key: str): + def hook(_module: nn.Module, _input, output) -> None: + captured[key] = _tensor_from_output(output).cpu() + + return hook + + hooks = [] + for name, module in targets.items(): + hooks.append(module.register_forward_hook(_make_hook(name))) + + try: + n_batches = 0 + for batch in tqdm(data, desc=f"Collecting ({phase})", leave=False): + if self.max_samples is not None and n_batches >= self.max_samples: + break + + captured.clear() + self._run_batch(model, batch) + + for name in targets: + if name in captured: + saved[name].append(captured[name]) + + hashes.append(self._hash_batch(batch)) + n_batches += 1 + finally: + for h in hooks: + h.remove() + + model.train(was_training) + + # ----- store results on self ----- + if phase == "original": + self.original_activations = saved + self.input_hashes = hashes + else: + self.quantized_activations = saved + self.quant_input_hashes = hashes + # Verify sample consistency + if self.input_hashes: + self._verify_hashes() + + # Invalidate any previous MSE since we have new activations + self.mse_results = None + + print(f"[ActivationMSELogger] Collected {n_batches} batches for phase '{phase}'") + + def compute_mse(self) -> dict[str, float]: + """Compute per-layer MSE between original and quantized activations. + + Returns: + Dict mapping layer name to its MSE value. + + Raises: + ValueError: If either phase has not been collected yet. + """ + if not self.original_activations: + raise ValueError( + "No original activations collected. Call collect(..., phase='original') first." + ) + if not self.quantized_activations: + raise ValueError( + "No quantized activations collected. 
Call collect(..., phase='quantized') first." + ) + + common_keys = sorted( + set(self.original_activations.keys()) & set(self.quantized_activations.keys()) + ) + if not common_keys: + raise ValueError( + "No matching layer names between original and quantized activations. " + "Ensure the same model architecture / layer_filter is used for both phases." + ) + + orig_only = set(self.original_activations.keys()) - set(self.quantized_activations.keys()) + quant_only = set(self.quantized_activations.keys()) - set(self.original_activations.keys()) + if orig_only: + print( + f"[ActivationMSELogger] Warning: {len(orig_only)} layers only in original (skipped)" + ) + if quant_only: + print( + f"[ActivationMSELogger] Warning: {len(quant_only)} layers only in quantized (skipped)" + ) + + sum_sq: dict[str, float] = dict.fromkeys(common_keys, 0.0) + count: dict[str, int] = dict.fromkeys(common_keys, 0) + + for name in common_keys: + orig_list = self.original_activations[name] + quant_list = self.quantized_activations[name] + n = min(len(orig_list), len(quant_list)) + for i in range(n): + o = orig_list[i].float() + q = quant_list[i].float() + if o.shape != q.shape: + print( + f"[ActivationMSELogger] Warning: shape mismatch for {name} " + f"batch {i}: {o.shape} vs {q.shape}, skipping" + ) + continue + sum_sq[name] += F.mse_loss(o, q, reduction="sum").item() + count[name] += o.numel() + + self.mse_results = { + key: (sum_sq[key] / count[key]) if count[key] > 0 else float("nan") + for key in common_keys + } + return self.mse_results + + def save(self, path: str | None = None) -> str: + """Save all state (activations, hashes, MSE) to disk via ``torch.save``. + + Args: + path: Explicit file path. If *None*, a timestamped file is created + inside ``self.save_dir`` (which must be set). + + Returns: + The path where the file was saved. 
+ """ + if path is None: + if self.save_dir is None: + raise ValueError("Provide a path or set save_dir in the constructor.") + os.makedirs(self.save_dir, exist_ok=True) + ts = datetime.now().strftime("%Y%m%d_%H%M%S") + path = os.path.join(self.save_dir, f"activation_mse_{ts}.pt") + + payload = { + "max_samples": self.max_samples, + "layer_filter": self.layer_filter, + "input_hashes": self.input_hashes, + "quant_input_hashes": self.quant_input_hashes, + "original_activations": self.original_activations, + "quantized_activations": self.quantized_activations, + "mse": self.mse_results, + } + torch.save(payload, path) + print(f"[ActivationMSELogger] Saved to {path}") + return path + + @classmethod + def load(cls, path: str) -> "ActivationMSELogger": + """Load a previously saved ``ActivationMSELogger`` from disk. + + Args: + path: Path to the ``.pt`` file created by :meth:`save`. + + Returns: + A new ``ActivationMSELogger`` instance with restored state. + """ + payload = torch.load(path, map_location="cpu", weights_only=False) + logger = cls( + max_samples=payload.get("max_samples", 16), + layer_filter=payload.get("layer_filter"), + ) + logger.original_activations = payload.get("original_activations", {}) + logger.quantized_activations = payload.get("quantized_activations", {}) + logger.input_hashes = payload.get("input_hashes", []) + logger.quant_input_hashes = payload.get("quant_input_hashes", []) + logger.mse_results = payload.get("mse") + print(f"[ActivationMSELogger] Loaded from {path}") + return logger + + def summary(self) -> str: + """Return a formatted string summarising per-layer MSE results. + + Computes MSE first if not already done. 
+ """ + if self.mse_results is None: + self.compute_mse() + assert self.mse_results is not None + + lines = ["Per-layer activation MSE (original vs quantized):"] + lines.extend( + f" {key}: {self.mse_results[key]:.6e}" for key in sorted(self.mse_results.keys()) + ) + return "\n".join(lines) + + # ------------------------------------------------------------------ + # Pre-materialized MSE data (cross-run / cross-codebase safety) + # ------------------------------------------------------------------ + + @staticmethod + def materialize_data( + data: Iterable, + path: str, + max_samples: int | None = None, + ) -> list[torch.Tensor]: + """Freeze the first *max_samples* batches from *data* into a ``.pt`` file. + + Each batch (``dict``, ``Tensor``, or ``list/tuple``) is normalised to a + single ``input_ids`` CPU tensor before saving. The resulting file is a + plain ``List[Tensor]`` that can be loaded in **any** codebase and passed + straight to :meth:`collect`. + + If *path* already exists it is **not** overwritten -- call + :meth:`load_data` instead. + + Args: + data: Iterable of batches (DataLoader, List[Tensor], etc.). + path: Destination ``.pt`` file path. + max_samples: How many batches to keep. ``None`` means all. + + Returns: + The materialised list of CPU tensors (same object that was saved). 
+ """ + samples: list[torch.Tensor] = [] + for batch in data: + if max_samples is not None and len(samples) >= max_samples: + break + if isinstance(batch, dict): + t = batch.get("input_ids", next(iter(batch.values()))) + elif isinstance(batch, torch.Tensor): + t = batch + elif isinstance(batch, (list, tuple)): + t = batch[0] + else: + raise TypeError(f"Unsupported batch type: {type(batch)}") + samples.append(t.cpu()) + + os.makedirs(os.path.dirname(path) or ".", exist_ok=True) + torch.save(samples, path) + print(f"[ActivationMSELogger] Materialised {len(samples)} MSE input samples -> {path}") + return samples + + @staticmethod + def load_data(path: str) -> list[torch.Tensor]: + """Load a previously materialised MSE input set. + + Args: + path: Path to the ``.pt`` file created by :meth:`materialize_data`. + + Returns: + ``List[Tensor]`` of input batches (on CPU). + """ + samples = torch.load(path, map_location="cpu", weights_only=True) + print(f"[ActivationMSELogger] Loaded {len(samples)} MSE input samples from {path}") + return samples + + # ------------------------------------------------------------------ + # Static / private helpers + # ------------------------------------------------------------------ + + @staticmethod + def _run_batch(model: nn.Module, batch) -> None: + """Run a single batch through the model (handles Tensor, dict, list/tuple). + + Automatically moves inputs to the model's device so that CPU-stored + materialized data works transparently with a CUDA model. 
+ """ + device = next(model.parameters()).device + if isinstance(batch, dict): + batch = { + k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items() + } + model(**batch) + elif isinstance(batch, torch.Tensor): + model(batch.to(device)) + elif isinstance(batch, (list, tuple)): + batch = tuple(t.to(device) if isinstance(t, torch.Tensor) else t for t in batch) + model(*batch) + else: + raise TypeError(f"Unsupported batch type: {type(batch)}") + + @staticmethod + def _hash_batch(batch) -> str: + """Compute SHA-256 hash of the primary input tensor in *batch*. + + - ``dict`` -> hashes ``batch["input_ids"]`` (falls back to first value). + - ``Tensor`` -> hashes the tensor directly. + - ``list/tuple`` -> hashes the first element. + """ + if isinstance(batch, dict): + t = batch.get("input_ids", next(iter(batch.values()))) + elif isinstance(batch, torch.Tensor): + t = batch + elif isinstance(batch, (list, tuple)): + t = batch[0] if batch else None + else: + return "" + + if t is None or not isinstance(t, torch.Tensor): + return "" + return hashlib.sha256(t.cpu().contiguous().numpy().tobytes()).hexdigest() + + def _verify_hashes(self) -> None: + """Compare input hashes between original and quantized phases.""" + n = min(len(self.input_hashes), len(self.quant_input_hashes)) + mismatches = sum(1 for i in range(n) if self.input_hashes[i] != self.quant_input_hashes[i]) + if mismatches: + print( + f"[ActivationMSELogger] WARNING: {mismatches}/{n} batches have " + f"different input hashes between original and quantized phases. " + f"The same data may not have been used for both phases!" 
+ ) + else: + print(f"[ActivationMSELogger] Input hash verification passed ({n}/{n} match)") From 22e2b95479ad020a88d2dbb065e833c30693f972 Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Thu, 19 Feb 2026 23:47:48 +0000 Subject: [PATCH 15/48] input amax sync added + tested gptq super sft checkpoint Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- modelopt/torch/quantization/model_calib.py | 31 ++++++++ tests/gpu/torch/quantization/test_gptq.py | 87 ++++++++++++++++++++++ 2 files changed, 118 insertions(+) diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 103a6dc453..5bb9739a39 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -135,6 +135,8 @@ def max_calibrate( for name, module in model.named_modules(): if hasattr(module, "layer_sync_moe_local_experts_amax"): module.layer_sync_moe_local_experts_amax() + elif hasattr(module, "sync_moe_local_experts_amax"): + module.sync_moe_local_experts_amax() if not distributed_sync: return @@ -1832,6 +1834,35 @@ def _set_input_quantizers_quant_mode(layer: nn.Module): module.disable_calib() +def _set_kv_quantizers_calib_mode(layer: nn.Module): + for name, module in layer.named_modules(): + if ( + isinstance(module, TensorQuantizer) + and ("k_bmm_quantizer" in name or "v_bmm_quantizer" in name) + and not module._disabled + and not module._dynamic + and module._calibrator is not None + ): + module._calibrator.reset() + module.disable_quant() + module.enable_calib() + + +def _set_kv_quantizers_quant_mode(layer: nn.Module): + for name, module in layer.named_modules(): + if ( + isinstance(module, TensorQuantizer) + and ("k_bmm_quantizer" in name or "v_bmm_quantizer" in name) + and not module._disabled + and not module._dynamic + and module._calibrator is not None + ): + if module._calibrator.compute_amax() is not None: + 
module.load_calib_amax() + module.enable_quant() + module.disable_calib() + + @contextlib.contextmanager def _disable_input_quantizers(layer: nn.Module): """Temporarily disable all enabled input quantizers in a layer.""" diff --git a/tests/gpu/torch/quantization/test_gptq.py b/tests/gpu/torch/quantization/test_gptq.py index 0c60bcd007..c47b48b1e2 100644 --- a/tests/gpu/torch/quantization/test_gptq.py +++ b/tests/gpu/torch/quantization/test_gptq.py @@ -20,7 +20,9 @@ from transformers import AutoModelForCausalLM, AutoTokenizer import modelopt.torch.quantization as mtq +from modelopt.torch.export.unified_export_hf import _export_quantized_weight from modelopt.torch.quantization.model_calib import blockwise_weight_update, update_hessian +from modelopt.torch.quantization.qtensor.nvfp4_tensor import NVFP4QTensor from modelopt.torch.utils.dataset_utils import create_forward_loop, get_dataset_dataloader RAND_SEED = 42 @@ -156,6 +158,91 @@ def test_gptq_updates(block_size, dim, model_weight, expect_weight_change): assert torch.allclose(model.weight.data, q_dq_weight), "Weight should be equal" +def test_gptq_export_roundtrip(): + """Test that GPTQ export + dequantize produces weights matching in-memory QDQ.""" + torch.manual_seed(RAND_SEED) + dim = 128 + block_size = 4 + + # Step 1: Create a simple linear model and quantize to install NVFP4 quantizers + model = torch.nn.Linear(dim, dim).to("cuda") + model.name = "linear" + original_weight = model.weight.data.clone() + input_tensor = torch.randn(2, 16, dim).to("cuda") + quant_cfg = mtq.NVFP4_DEFAULT_CFG + + mtq.quantize(model, quant_cfg, forward_loop=lambda m: m(input_tensor)) + + # Restore original weight before GPTQ + model.weight.data = original_weight.clone() + + # Step 2: Perform GPTQ — compute Hessian and update weights + hessian = torch.zeros(dim, dim, dtype=torch.float32) + n_samples = 0 + hessian, n_samples = update_hessian(input_tensor, hessian, n_samples) + hessian = hessian.to("cuda") + + 
blockwise_weight_update(model, hessian, block_size, percdamp=0.1) + + # Save the QDQ reference from the quantizer applied to GPTQ'd weights + gptq_weight_shape = model.weight.data.shape + gptq_weight_dtype = model.weight.data.dtype + qdq_ref = model.weight.data.clone() + + # Step 3: Export — converts weight to packed NVFP4 and registers scale buffers + _export_quantized_weight(model, torch.bfloat16) + + # Verify export produced the expected buffers + assert hasattr(model, "weight_scale"), "Export should register weight_scale buffer" + assert hasattr(model, "weight_scale_2"), "Export should register weight_scale_2 buffer" + + # Step 4: Dequantize the exported packed weight and compare with QDQ reference + packed_weight = model.weight.data + weight_scale = model.weight_scale + weight_scale_2 = model.weight_scale_2 + + nvfp4_qtensor = NVFP4QTensor(gptq_weight_shape, gptq_weight_dtype, packed_weight) + deq_weight = nvfp4_qtensor.dequantize( + dtype=torch.bfloat16, + scale=weight_scale, + double_scale=weight_scale_2, + block_sizes={-1: 16}, + ) + + assert deq_weight.shape == qdq_ref.shape, ( + f"Shape mismatch: dequantized {deq_weight.shape} vs QDQ ref {qdq_ref.shape}" + ) + diff = (deq_weight - qdq_ref.to(torch.bfloat16)).abs() + max_diff = diff.max().item() + max_diff_idx = diff.argmax().item() + max_diff_row = max_diff_idx // deq_weight.shape[1] + max_diff_col = max_diff_idx % deq_weight.shape[1] + num_mismatched = (diff > 1e-3).sum().item() + total_elements = diff.numel() + + print("\n--- Diff Stats ---") + print(f" Max diff: {max_diff}") + print(f" Mean diff: {diff.mean().item()}") + print(f" Median diff: {diff.median().item()}") + print(f" Std diff: {diff.std().item()}") + print( + f" Mismatched (>1e-3): {num_mismatched}/{total_elements} " + f"({100 * num_mismatched / total_elements:.2f}%)" + ) + print( + f" Max diff at [{max_diff_row}, {max_diff_col}]: " + f"deq={deq_weight[max_diff_row, max_diff_col].item()}, " + f"qdq_ref={qdq_ref[max_diff_row, 
max_diff_col].item()}" + ) + + assert torch.allclose(deq_weight, qdq_ref.to(torch.bfloat16), atol=1e-2), ( + f"Dequantized weight does not match QDQ reference. " + f"Max diff: {max_diff} at [{max_diff_row}, {max_diff_col}] " + f"(deq={deq_weight[max_diff_row, max_diff_col].item()}, " + f"qdq_ref={qdq_ref[max_diff_row, max_diff_col].item()})" + ) + + @pytest.mark.parametrize( "quant_cfg", [mtq.NVFP4_DEFAULT_CFG, mtq.FP8_DEFAULT_CFG, mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG] ) From 10d21baecf3acd4e85ce625a6d1549483a9ccc90 Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Mon, 23 Feb 2026 21:24:34 +0000 Subject: [PATCH 16/48] checkpoints generated on 0223 Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- examples/llm_ptq/hf_ptq.py | 1 - modelopt/torch/quantization/config.py | 22 ++++++--- modelopt/torch/quantization/model_calib.py | 57 ++++++++++++++++------ 3 files changed, 57 insertions(+), 23 deletions(-) diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index 9c1c58e5c9..ec587f33ec 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -871,7 +871,6 @@ def _compute_perplexity(model, data, batch_size: int = 1): ppl = _compute_perplexity(full_model, eval_data) print(f"Wikitext-2 perplexity: {round(ppl, 2):.2f}") - # Load any missing weights from non-standard safetensors (handled in get_model for non-low-memory mode) # Store the MTP layer prefixes on the model for later exclusion from quantization mtp_layer_prefixes, mtp_state_dict = load_mtp_weights(full_model, args.pyt_ckpt_path) diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index e832c999a5..8e8e4f98b8 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -156,12 +156,23 @@ "*mlp.gate.*": {"enable": False}, # Skip the MOE router "*mlp.shared_expert_gate.*": {"enable": False}, # Skip the MOE router 
"*linear_attn.conv1d*": {"enable": False}, - "*mixer.conv1d*": {"enable": False}, # Skip mamba conv1d + "*mixer.conv1d*": {"enable": False}, "*output_layer*": {"enable": False}, "output.*": {"enable": False}, "default": {"enable": False}, } +super_disabled_quantizer_cfg = { + "*fc1_latent_proj*": {"enable": False}, # Skip Latent MOE + "*fc2_latent_proj*": {"enable": False}, # Skip Latent MOE + "*q_proj*": {"enable": False}, # Skip QKV Linear + "*k_proj*": {"enable": False}, # Skip QKV Linear + "*v_proj*": {"enable": False}, # Skip QKV Linear + "*o_proj*": {"enable": False}, # Skip Output Linear + "*mtp*": {"enable": False}, # Skip MTP layers +} + + _mamba_moe_disabled_quantizer_cfg = { "*fc1_latent_proj*": {"enable": False}, # Skip Latent MOE "*fc2_latent_proj*": {"enable": False}, # Skip Latent MOE @@ -186,7 +197,7 @@ "enable": True, }, **_default_disabled_quantizer_cfg, - **_mamba_moe_disabled_quantizer_cfg, + **super_disabled_quantizer_cfg, "*mixer.in_proj*": {"enable": False}, # Skip mamba linear "*mixer.out_proj*": {"enable": False}, # Skip mamba linear }, @@ -208,7 +219,7 @@ "enable": True, }, **_default_disabled_quantizer_cfg, - **_mamba_moe_disabled_quantizer_cfg, + **super_disabled_quantizer_cfg, "*mixer.in_proj*": {"enable": False}, # Skip mamba linear "*mixer.out_proj*": {"enable": False}, # Skip mamba linear }, @@ -341,7 +352,7 @@ "enable": False, }, **_default_disabled_quantizer_cfg, - **_mamba_moe_disabled_quantizer_cfg, + # **_mamba_moe_disabled_quantizer_cfg, "*mixer.in_proj*": {"enable": False}, # Skip mamba linear "*mixer.out_proj*": {"enable": False}, # Skip mamba linear }, @@ -366,9 +377,6 @@ "enable": True, }, **_default_disabled_quantizer_cfg, - **_mamba_moe_disabled_quantizer_cfg, - "*mixer.in_proj*": {"enable": False}, # Skip mamba linear - "*mixer.out_proj*": {"enable": False}, # Skip mamba linear }, "algorithm": { "method": "gptq", diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 
5bb9739a39..7971ce59f6 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -138,6 +138,26 @@ def max_calibrate( elif hasattr(module, "sync_moe_local_experts_amax"): module.sync_moe_local_experts_amax() + for name, module in list(model.named_modules()): + if isinstance(module, TensorQuantizer) and not module._disabled: + if module._calibrator is not None and not module._dynamic and hasattr(module, "_amax"): + # Get the initial amax from max calibration + initial_amax = module._amax.clone().detach() + + is_nvfp4_static = ( + module.is_static_block_quant + and module._num_bits == (2, 1) + and module._block_sizes is not None + and module._block_sizes.get("scale_bits") == (4, 3) + ) + + if is_nvfp4_static: + # Compute and set global_amax + global_amax = reduce_amax(initial_amax, axis=None) + + # Convert to NVFP4StaticQuantizer in-place + NVFP4StaticQuantizer.from_tensor_quantizer(module, global_amax=global_amax) + if not distributed_sync: return @@ -342,6 +362,7 @@ def mse_calibrate( if fp8_scale_sweep and is_nvfp4_static: # Replace calibrator with NVFP4MSECalibrator + print("mse_calibrate: Replacing calibrator with NVFP4MSECalibrator") module._calibrator = NVFP4MSECalibrator( amax=initial_amax, axis=module._calibrator._axis, @@ -628,6 +649,7 @@ def quant_func(x, amax, quantizer=weight_quantizer): error_func = helper.get_error_func() if fp8_scale_sweep and is_nvfp4_static: + print("local_hessian_calibrate: Replacing calibrator with NVFP4MSECalibrator") weight_quantizer._calibrator = NVFP4MSECalibrator( amax=initial_amax, axis=weight_quantizer._calibrator._axis if weight_quantizer._calibrator else None, @@ -2073,21 +2095,26 @@ def gptq( "n_samples": 0, } - # Phase 2: Register hooks to collect Hessians during forward passes - def hessian_hook(module, input, output): - """Hook to intercept activations and update hessian matrix.""" - if hasattr(module, "input_quantizer") and module.input_quantizer.is_enabled: - inp = 
module.input_quantizer(input[0]) - else: - inp = input[0] - state = hessian_state[module.name] - hessian, n_samples = update_hessian(inp, state["hessian"], state["n_samples"]) - hessian_state[module.name] = {"hessian": hessian, "n_samples": n_samples} + # Phase 2: Patch forwards to collect Hessians (similar to local_hessian_calibrate) + def _make_hessian_forward(module_name): + def hessian_forward(self, input, *args, **kwargs): + inp = input.to_local() if hasattr(input, "to_local") else input + state = hessian_state[module_name] + hessian, n_samples = update_hessian(inp, state["hessian"], state["n_samples"]) + hessian_state[module_name] = {"hessian": hessian, "n_samples": n_samples} + + self.weight_quantizer.disable() + out = self._forward_no_gptq_hessian(input, *args, **kwargs) + self.weight_quantizer.enable() + return out + + return hessian_forward - handles = [] + patched_modules = [] for name, module in layer.named_modules(): if is_quantized_linear(module) and module.weight_quantizer.is_enabled: - handles.append(module.register_forward_hook(hessian_hook)) + bind_forward_method(module, _make_hessian_forward(name), "_forward_no_gptq_hessian") + patched_modules.append(module) # Run forward passes with the provided inputs to collect Hessians hessian_start = time.time() @@ -2097,9 +2124,9 @@ def hessian_hook(module, input, output): for args, kwargs_input in inputs: layer(*args, **kwargs_input) - # Remove hooks after collecting Hessians - for handle in handles: - handle.remove() + # Unpatch forwards + for module in patched_modules: + unpatch_forward_method(module, "_forward_no_gptq_hessian") torch.cuda.synchronize() if torch.cuda.is_available() else None hessian_time = time.time() - hessian_start From 188fa1da11869cc732252c4d42c62734e8cd6068 Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Wed, 4 Feb 2026 18:20:37 +0000 Subject: [PATCH 17/48] tested perplexity Signed-off-by: Suguna Velury 
<178320438+sugunav14@users.noreply.github.com> --- modelopt/torch/quantization/config.py | 134 -------------- modelopt/torch/quantization/mode.py | 2 + modelopt/torch/quantization/model_calib.py | 205 +-------------------- 3 files changed, 5 insertions(+), 336 deletions(-) diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index 8e8e4f98b8..92ae76bc34 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -340,125 +340,6 @@ }, } -NVFP4_STATIC_WO_GPTQ_CFG = { - "quant_cfg": { - "*weight_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, - "axis": None, - "enable": True, - }, - "*input_quantizer": { - "enable": False, - }, - **_default_disabled_quantizer_cfg, - # **_mamba_moe_disabled_quantizer_cfg, - "*mixer.in_proj*": {"enable": False}, # Skip mamba linear - "*mixer.out_proj*": {"enable": False}, # Skip mamba linear - }, - "algorithm": { - "method": "gptq", - "use_sequential": True, - }, -} - -NVFP4_STATIC_GPTQ_CFG = { - "quant_cfg": { - "*weight_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, - "axis": None, - "enable": True, - }, - "*input_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, - "axis": None, - "enable": True, - }, - **_default_disabled_quantizer_cfg, - }, - "algorithm": { - "method": "gptq", - "use_sequential": True, - }, -} - -NVFP4_STATIC_WO_GPTQ_LITE_CFG = { - "quant_cfg": { - "*weight_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, - "axis": None, - "enable": True, - }, - "*input_quantizer": { - "enable": False, - }, - **_default_disabled_quantizer_cfg, - }, - "algorithm": { - "method": "gptq_lite", - "use_sequential": False, - }, -} - -NVFP4_STATIC_WO_CFG = { - "quant_cfg": { - "*weight_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, 
"type": "static", "scale_bits": (4, 3)}, - "axis": None, - "enable": True, - }, - "*input_quantizer": { - "enable": False, - }, - **_default_disabled_quantizer_cfg, - }, - "algorithm": { - "method": "max", - "use_sequential": False, - }, -} - -NVFP4_STATIC_WO_GPTQ_LITE_CFG = { - "quant_cfg": { - "*weight_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, - "axis": None, - "enable": True, - }, - "*input_quantizer": { - "enable": False, - }, - **_default_disabled_quantizer_cfg, - }, - "algorithm": { - "method": "gptq_lite", - "use_sequential": False, - }, -} - -NVFP4_DYNAMIC_WO_CFG = { - "quant_cfg": { - "*weight_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, - "axis": None, - "enable": True, - }, - "*input_quantizer": { - "enable": False, - }, - **_default_disabled_quantizer_cfg, - }, - "algorithm": { - "method": "gptq_lite", - "use_sequential": False, - }, -} INT4_AWQ_CFG = { "quant_cfg": { @@ -1236,21 +1117,6 @@ class QuantizeAlgorithmConfig(ModeloptBaseConfig): ), ) - checkpoint_every_n_layers: int | None = ModeloptField( - default=None, - title="Save intermediate checkpoint every N layers during sequential calibration.", - ) - - checkpoint_dir: str | None = ModeloptField( - default=None, - title="Directory for saving/loading intermediate GPTQ checkpoints.", - ) - - resume_from_layer: int = ModeloptField( - default=0, - title="Layer index to resume sequential calibration from (0 = start from beginning).", - ) - class MaxCalibConfig(QuantizeAlgorithmConfig): """The config for max calibration algorithm. 
diff --git a/modelopt/torch/quantization/mode.py b/modelopt/torch/quantization/mode.py index 88e93bb770..efc66ffa94 100644 --- a/modelopt/torch/quantization/mode.py +++ b/modelopt/torch/quantization/mode.py @@ -255,6 +255,8 @@ def wrapped_calib_func( else: # Direct calibration (existing behavior) func(model, forward_loop=forward_loop, **kwargs) + else: + raise ValueError(f"No calibration function provided for method: {method}") # Lets get the latest metadata for the quantizer states metadata = {} diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 7971ce59f6..100c749234 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -1825,196 +1825,11 @@ def hessian_hook(module, input, output): print_rank_0("GPTQ-lite quantization completed successfully") -def _set_input_quantizers_calib_mode(layer: nn.Module): - """Set all input quantizers of a layer to calibration mode.""" - for name, module in layer.named_modules(): - if ( - isinstance(module, TensorQuantizer) - and "input_quantizer" in name - and not module._disabled - and not module._dynamic - and module._calibrator is not None - ): - module._calibrator.reset() - module.disable_quant() - module.enable_calib() - - -def _set_input_quantizers_quant_mode(layer: nn.Module): - """Load fresh amaxes and restore all input quantizers of a layer to quant mode.""" - for name, module in layer.named_modules(): - if ( - isinstance(module, TensorQuantizer) - and "input_quantizer" in name - and not module._disabled - and not module._dynamic - and module._calibrator is not None - ): - if module._calibrator.compute_amax() is not None: - module.load_calib_amax() - module.enable_quant() - module.disable_calib() - - -def _set_kv_quantizers_calib_mode(layer: nn.Module): - for name, module in layer.named_modules(): - if ( - isinstance(module, TensorQuantizer) - and ("k_bmm_quantizer" in name or "v_bmm_quantizer" in name) - and not 
module._disabled - and not module._dynamic - and module._calibrator is not None - ): - module._calibrator.reset() - module.disable_quant() - module.enable_calib() - - -def _set_kv_quantizers_quant_mode(layer: nn.Module): - for name, module in layer.named_modules(): - if ( - isinstance(module, TensorQuantizer) - and ("k_bmm_quantizer" in name or "v_bmm_quantizer" in name) - and not module._disabled - and not module._dynamic - and module._calibrator is not None - ): - if module._calibrator.compute_amax() is not None: - module.load_calib_amax() - module.enable_quant() - module.disable_calib() - - -@contextlib.contextmanager -def _disable_input_quantizers(layer: nn.Module): - """Temporarily disable all enabled input quantizers in a layer.""" - enabled_quantizers = [] - for name, module in layer.named_modules(): - if ( - isinstance(module, TensorQuantizer) - and "input_quantizer" in name - and not module._disabled - ): - module.disable() - enabled_quantizers.append(module) - try: - yield - finally: - for module in enabled_quantizers: - module.enable() - - -def save_fake_checkpoint(model: nn.Module, output_dir: str) -> None: - """Save fake quant checkpoint using save_pretrained() (HuggingFace format). - - Args: - model: The quantized model to save. - output_dir: Directory to write the checkpoint into. - """ - from modelopt.torch.opt.conversion import ModeloptStateManager, modelopt_state - from modelopt.torch.quantization.conversion import quantizer_state as get_quantizer_state - - os.makedirs(output_dir, exist_ok=True) - - # Remove accelerate hooks before saving to avoid pickling errors in modelopt_state. - # Accelerate hooks contain local functions (closures like 'add_hook_to_module..new_forward') - # that can't be pickled. Even after removing hooks from modules, they may still be captured - # in closures within quantizer_state metadata when modelopt_state() calls update_last_state_before_save(). 
- try: - from accelerate.hooks import remove_hook_from_module - - remove_hook_from_module(model, recurse=True) - except ImportError: - pass - - # Save model weights first (without modelopt_state to avoid pickling error) - model.save_pretrained(output_dir, save_modelopt_state=False) - - # Manually save modelopt_state after removing hooks and rebuilding quantizer_state. - # We need to rebuild quantizer_state because hooks may have been captured in closures - # when quantizer_state() was called during update_last_state_before_save() inside modelopt_state(). - if ModeloptStateManager.is_converted(model): - modelopt_state_path = os.path.join(output_dir, "modelopt_state.pth") - state = modelopt_state(model) - - # Rebuild quantizer_state in metadata to remove any hook references captured in closures - if "modelopt_state_dict" in state and isinstance(state["modelopt_state_dict"], list): - cleaned_state_dict = [] - for entry in state["modelopt_state_dict"]: - if isinstance(entry, tuple) and len(entry) >= 2: - mode_str, state_dict_entry = entry[0], entry[1] - if isinstance(state_dict_entry, dict) and "metadata" in state_dict_entry: - # Rebuild quantizer_state after hooks are removed - cleaned_entry = state_dict_entry.copy() - cleaned_metadata = cleaned_entry["metadata"].copy() - cleaned_metadata["quantizer_state"] = get_quantizer_state(model) - cleaned_entry["metadata"] = cleaned_metadata - cleaned_state_dict.append((mode_str, cleaned_entry)) - else: - cleaned_state_dict.append(entry) - else: - cleaned_state_dict.append(entry) - state["modelopt_state_dict"] = cleaned_state_dict - - torch.save(state, modelopt_state_path) - print_rank_0(f"Saved ModelOpt state to {modelopt_state_path}") - - -def _save_gptq_checkpoint( - model: nn.Module, checkpoint_dir: str, last_layer_idx: int, total_layers: int -) -> None: - """Save intermediate GPTQ checkpoint with metadata for resume support. 
- - Saves accelerate hooks before calling save_fake_checkpoint (which removes them), - then re-attaches them so the model remains functional for subsequent layers. - """ - print_rank_0( - f"Saving GPTQ checkpoint after layer {last_layer_idx}/{total_layers - 1} to {checkpoint_dir}" - ) - - # Save accelerate hooks before save_fake_checkpoint removes them. - # We need to re-attach them after saving so the model keeps working. - saved_hooks = {} - for name, module in model.named_modules(): - if hasattr(module, "_hf_hook"): - saved_hooks[name] = module._hf_hook - - try: - save_fake_checkpoint(model, checkpoint_dir) - finally: - # Re-attach accelerate hooks so the model keeps working for remaining layers. - if saved_hooks: - try: - from accelerate.hooks import add_hook_to_module - - name_to_module = dict(model.named_modules()) - for name, hook in saved_hooks.items(): - if name in name_to_module: - add_hook_to_module(name_to_module[name], hook) - print_rank_0(f"Re-attached {len(saved_hooks)} accelerate hooks") - except ImportError: - pass - - # Save checkpoint metadata for resume support. - meta = { - "last_completed_layer": last_layer_idx, - "total_layers": total_layers, - "timestamp": datetime.datetime.now().isoformat(), - } - meta_path = os.path.join(checkpoint_dir, "gptq_checkpoint_meta.json") - with open(meta_path, "w") as f: - json.dump(meta, f, indent=2) - print_rank_0(f"GPTQ checkpoint saved (layer {last_layer_idx}/{total_layers - 1})") - - @torch.no_grad() def sequential_calibrate( model: nn.Module, forward_loop: ForwardLoop, calib_func: Callable, - checkpoint_every_n_layers: int | None = None, - checkpoint_dir: str | None = None, - resume_from_layer: int = 0, **calib_kwargs, ): """Sequential calibration - a sequential layer-by-layer calibration algorithm. 
@@ -2064,14 +1879,14 @@ def _layer_forward_loop(m, _inputs=layer_inputs): def gptq( layer: nn.Module, inputs: list[tuple[tuple, dict]], + forward_loop: ForwardLoop, percdamp: float = 0.01, block_size: int = 128, **kwargs, ): """GPTQ quantization - a GPTQ variant.""" - import time - - total_start = time.time() + # Set weight amax and activation amax'es for the current layer using max_calibrate + max_calibrate(layer, forward_loop=forward_loop) # Dictionary to store hessian matrices for all linear layers in this decoder hessian_state = {} @@ -2117,7 +1932,6 @@ def hessian_forward(self, input, *args, **kwargs): patched_modules.append(module) # Run forward passes with the provided inputs to collect Hessians - hessian_start = time.time() print_rank_0( f"Computing Hessians for {len(tensor_mapping)} linear layers using {len(inputs)} batches..." ) @@ -2128,11 +1942,8 @@ def hessian_forward(self, input, *args, **kwargs): for module in patched_modules: unpatch_forward_method(module, "_forward_no_gptq_hessian") - torch.cuda.synchronize() if torch.cuda.is_available() else None - hessian_time = time.time() - hessian_start # Phase 3: Update weights using computed Hessians (same as gptq_lite) - weight_update_start = time.time() print_rank_0("Updating weights using GPTQ algorithm...") for name, module in layer.named_modules(): if is_quantized_linear(module) and module.weight_quantizer.is_enabled: @@ -2144,13 +1955,3 @@ def hessian_forward(self, input, *args, **kwargs): # Free memory del hessian_state[module.name] torch.cuda.empty_cache() - - torch.cuda.synchronize() if torch.cuda.is_available() else None - weight_update_time = time.time() - weight_update_start - - total_time = time.time() - total_start - print_rank_0( - f"GPTQ timing - Hessian: {hessian_time:.2f}s, " - f"Weight update: {weight_update_time:.2f}s, " - f"Total: {total_time:.2f}s" - ) From 599227eb9f27cc36522eab6d4424d612423bba2d Mon Sep 17 00:00:00 2001 From: Suguna Velury 
<178320438+sugunav14@users.noreply.github.com> Date: Mon, 9 Feb 2026 16:46:47 +0000 Subject: [PATCH 18/48] tested, revert later Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- examples/llm_ptq/hf_ptq.py | 76 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 76 insertions(+) diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index ec587f33ec..84a84041b5 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -726,6 +726,82 @@ def export_quantized( "Unified HF export format does not specify inference tensor parallel or pipeline parallel. " "They will be set at deployment time." ) + if True: + # Disable quantizers + # mtq.fold_weight(full_model) + # print("Folded weights") + print("Disabling quantizers for perplexity evaluation (weights are already QDQ'ed)") + mtq.disable_quantizer(full_model, "*") + if True: + # mtq.fold_weight(full_model) + import os + + import torch.nn.functional as F + from datasets import load_dataset + from tqdm import trange + from transformers import AutoTokenizer + + # Set cache directory to work directory to avoid disk space issues + cache_dir = os.path.join( + os.path.dirname(os.path.abspath(__file__)), ".hf_cache" + ) + os.makedirs(cache_dir, exist_ok=True) + os.environ["HF_DATASETS_CACHE"] = cache_dir + print(f"Using HuggingFace datasets cache: {cache_dir}") + + def _get_wikitext2(tokenizer: AutoTokenizer, sequence_length: int): + test_dataset_raw = load_dataset( + "wikitext", "wikitext-2-raw-v1", split="test", cache_dir=cache_dir + ) + test_dataset_tok = tokenizer( + "\n\n".join(test_dataset_raw["text"]), return_tensors="pt" + ).input_ids + num_test_sequences = test_dataset_tok.numel() // sequence_length + test_loader = [ + test_dataset_tok[:, i * sequence_length : (i + 1) * sequence_length] + for i in range(num_test_sequences) + ] + return test_loader + + @torch.no_grad() + def _compute_perplexity(model, data, batch_size: int = 1): + num_samples = len(data) + 
device = next(model.parameters()).device + # Running estimate of negative log-likelihood + nll_running = 0 + # Number of tokens processed to far + tokens_processed = 0 + # Loop through each batch + for i in trange( + 0, num_samples, batch_size, desc="Computing perplexity", leave=False + ): + j = min(i + batch_size, num_samples) + inputs = torch.cat(data[i:j]).to(device) + # Forward pass through the model + lm_logits = model(inputs).logits + # Shift logits and labels for next token prediction + shift_logits = lm_logits[:, :-1, :].contiguous() + shift_labels = inputs[:, 1:] + # Compute loss + loss = F.cross_entropy( + shift_logits.reshape(-1, shift_logits.size(-1)), + shift_labels.reshape(-1), + ) + # Calculate negative log likelihood + a = shift_labels.numel() / (tokens_processed + shift_labels.numel()) + b = tokens_processed / (tokens_processed + shift_labels.numel()) + nll_running = a * loss + b * nll_running + # Update number of processed tokens + tokens_processed += shift_labels.numel() + # Compute perplexity + ppl = nll_running.exp().item() + return ppl + + eval_data = _get_wikitext2(tokenizer, 2048) + ppl = _compute_perplexity(full_model, eval_data) + print(f"Wikitext-2 perplexity: {round(ppl, 2):.2f}") + + breakpoint() if True: import os From 60df0d8e70e93c5112711659bf3861c8580e67fc Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Tue, 10 Feb 2026 04:41:46 +0000 Subject: [PATCH 19/48] tested Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- examples/llm_ptq/hf_ptq.py | 220 -------------------------- modelopt/torch/quantization/config.py | 94 +++++++++++ 2 files changed, 94 insertions(+), 220 deletions(-) diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index 84a84041b5..18c0bb8d57 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -726,226 +726,6 @@ def export_quantized( "Unified HF export format does not specify inference tensor 
parallel or pipeline parallel. " "They will be set at deployment time." ) - if True: - # Disable quantizers - # mtq.fold_weight(full_model) - # print("Folded weights") - print("Disabling quantizers for perplexity evaluation (weights are already QDQ'ed)") - mtq.disable_quantizer(full_model, "*") - if True: - # mtq.fold_weight(full_model) - import os - - import torch.nn.functional as F - from datasets import load_dataset - from tqdm import trange - from transformers import AutoTokenizer - - # Set cache directory to work directory to avoid disk space issues - cache_dir = os.path.join( - os.path.dirname(os.path.abspath(__file__)), ".hf_cache" - ) - os.makedirs(cache_dir, exist_ok=True) - os.environ["HF_DATASETS_CACHE"] = cache_dir - print(f"Using HuggingFace datasets cache: {cache_dir}") - - def _get_wikitext2(tokenizer: AutoTokenizer, sequence_length: int): - test_dataset_raw = load_dataset( - "wikitext", "wikitext-2-raw-v1", split="test", cache_dir=cache_dir - ) - test_dataset_tok = tokenizer( - "\n\n".join(test_dataset_raw["text"]), return_tensors="pt" - ).input_ids - num_test_sequences = test_dataset_tok.numel() // sequence_length - test_loader = [ - test_dataset_tok[:, i * sequence_length : (i + 1) * sequence_length] - for i in range(num_test_sequences) - ] - return test_loader - - @torch.no_grad() - def _compute_perplexity(model, data, batch_size: int = 1): - num_samples = len(data) - device = next(model.parameters()).device - # Running estimate of negative log-likelihood - nll_running = 0 - # Number of tokens processed to far - tokens_processed = 0 - # Loop through each batch - for i in trange( - 0, num_samples, batch_size, desc="Computing perplexity", leave=False - ): - j = min(i + batch_size, num_samples) - inputs = torch.cat(data[i:j]).to(device) - # Forward pass through the model - lm_logits = model(inputs).logits - # Shift logits and labels for next token prediction - shift_logits = lm_logits[:, :-1, :].contiguous() - shift_labels = inputs[:, 1:] - # 
Compute loss - loss = F.cross_entropy( - shift_logits.reshape(-1, shift_logits.size(-1)), - shift_labels.reshape(-1), - ) - # Calculate negative log likelihood - a = shift_labels.numel() / (tokens_processed + shift_labels.numel()) - b = tokens_processed / (tokens_processed + shift_labels.numel()) - nll_running = a * loss + b * nll_running - # Update number of processed tokens - tokens_processed += shift_labels.numel() - # Compute perplexity - ppl = nll_running.exp().item() - return ppl - - eval_data = _get_wikitext2(tokenizer, 2048) - ppl = _compute_perplexity(full_model, eval_data) - print(f"Wikitext-2 perplexity: {round(ppl, 2):.2f}") - - breakpoint() - - if True: - import os - - import torch.nn.functional as F - from datasets import load_dataset - from tqdm import trange - from transformers import AutoTokenizer - - # Set cache directory to work directory to avoid disk space issues - cache_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), ".hf_cache") - os.makedirs(cache_dir, exist_ok=True) - os.environ["HF_DATASETS_CACHE"] = cache_dir - print(f"Using HuggingFace datasets cache: {cache_dir}") - - def _get_wikitext2(tokenizer: AutoTokenizer, sequence_length: int): - test_dataset_raw = load_dataset( - "wikitext", "wikitext-2-raw-v1", split="test", cache_dir=cache_dir - ) - test_dataset_tok = tokenizer( - "\n\n".join(test_dataset_raw["text"]), return_tensors="pt" - ).input_ids - num_test_sequences = test_dataset_tok.numel() // sequence_length - test_loader = [ - test_dataset_tok[:, i * sequence_length : (i + 1) * sequence_length] - for i in range(num_test_sequences) - ] - return test_loader - - @torch.no_grad() - def _compute_perplexity(model, data, batch_size: int = 1): - num_samples = len(data) - device = next(model.parameters()).device - # Running estimate of negative log-likelihood - nll_running = 0 - # Number of tokens processed to far - tokens_processed = 0 - # Loop through each batch - for i in trange( - 0, num_samples, batch_size, 
desc="Computing perplexity", leave=False - ): - j = min(i + batch_size, num_samples) - inputs = torch.cat(data[i:j]).to(device) - # Forward pass through the model - lm_logits = model(inputs).logits - # Shift logits and labels for next token prediction - shift_logits = lm_logits[:, :-1, :].contiguous() - shift_labels = inputs[:, 1:] - # Compute loss - loss = F.cross_entropy( - shift_logits.reshape(-1, shift_logits.size(-1)), - shift_labels.reshape(-1), - ) - # Calculate negative log likelihood - a = shift_labels.numel() / (tokens_processed + shift_labels.numel()) - b = tokens_processed / (tokens_processed + shift_labels.numel()) - nll_running = a * loss + b * nll_running - # Update number of processed tokens - tokens_processed += shift_labels.numel() - # Compute perplexity - ppl = nll_running.exp().item() - return ppl - - eval_data = _get_wikitext2(tokenizer, 2048) - ppl = _compute_perplexity(full_model, eval_data) - print(f"Wikitext-2 perplexity: {round(ppl, 2):.2f}") - breakpoint() - - if args.export_qdq_weights: - # Disable quantizers - if "gptq" not in args.qformat: - mtq.fold_weight(full_model) - print("Folded weights") - - print(f"Saving model to {args.export_path}") - full_model.save_pretrained(args.export_path) - - if True: - import os - - import torch.nn.functional as F - from datasets import load_dataset - from tqdm import trange - from transformers import AutoTokenizer - - # Set cache directory to work directory to avoid disk space issues - cache_dir = os.path.join( - os.path.dirname(os.path.abspath(__file__)), ".hf_cache" - ) - os.makedirs(cache_dir, exist_ok=True) - os.environ["HF_DATASETS_CACHE"] = cache_dir - print(f"Using HuggingFace datasets cache: {cache_dir}") - - def _get_wikitext2(tokenizer: AutoTokenizer, sequence_length: int): - test_dataset_raw = load_dataset( - "wikitext", "wikitext-2-raw-v1", split="test", cache_dir=cache_dir - ) - test_dataset_tok = tokenizer( - "\n\n".join(test_dataset_raw["text"]), return_tensors="pt" - ).input_ids - 
num_test_sequences = test_dataset_tok.numel() // sequence_length - test_loader = [ - test_dataset_tok[:, i * sequence_length : (i + 1) * sequence_length] - for i in range(num_test_sequences) - ] - return test_loader - - @torch.no_grad() - def _compute_perplexity(model, data, batch_size: int = 1): - num_samples = len(data) - device = next(model.parameters()).device - # Running estimate of negative log-likelihood - nll_running = 0 - # Number of tokens processed to far - tokens_processed = 0 - # Loop through each batch - for i in trange( - 0, num_samples, batch_size, desc="Computing perplexity", leave=False - ): - j = min(i + batch_size, num_samples) - inputs = torch.cat(data[i:j]).to(device) - # Forward pass through the model - lm_logits = model(inputs).logits - # Shift logits and labels for next token prediction - shift_logits = lm_logits[:, :-1, :].contiguous() - shift_labels = inputs[:, 1:] - # Compute loss - loss = F.cross_entropy( - shift_logits.reshape(-1, shift_logits.size(-1)), - shift_labels.reshape(-1), - ) - # Calculate negative log likelihood - a = shift_labels.numel() / (tokens_processed + shift_labels.numel()) - b = tokens_processed / (tokens_processed + shift_labels.numel()) - nll_running = a * loss + b * nll_running - # Update number of processed tokens - tokens_processed += shift_labels.numel() - # Compute perplexity - ppl = nll_running.exp().item() - return ppl - - eval_data = _get_wikitext2(tokenizer, 2048) - ppl = _compute_perplexity(full_model, eval_data) - print(f"Wikitext-2 perplexity: {round(ppl, 2):.2f}") # Load any missing weights from non-standard safetensors (handled in get_model for non-low-memory mode) # Store the MTP layer prefixes on the model for later exclusion from quantization diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index 92ae76bc34..777406d18e 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -340,6 +340,100 @@ }, } 
+NVFP4_STATIC_WO_GPTQ_CFG = { + "quant_cfg": { + "*weight_quantizer": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + }, + "*input_quantizer": { + "enable": False, + }, + **_default_disabled_quantizer_cfg, + }, + "algorithm": { + "method": "gptq", + "use_sequential": True, + }, +} + +NVFP4_STATIC_WO_GPTQ_LITE_CFG = { + "quant_cfg": { + "*weight_quantizer": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + }, + "*input_quantizer": { + "enable": False, + }, + **_default_disabled_quantizer_cfg, + }, + "algorithm": { + "method": "gptq_lite", + "use_sequential": False, + }, +} + +NVFP4_STATIC_WO_CFG = { + "quant_cfg": { + "*weight_quantizer": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + }, + "*input_quantizer": { + "enable": False, + }, + **_default_disabled_quantizer_cfg, + }, + "algorithm": { + "method": "max", + "use_sequential": False, + }, +} + +NVFP4_STATIC_WO_GPTQ_LITE_CFG = { + "quant_cfg": { + "*weight_quantizer": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + }, + "*input_quantizer": { + "enable": False, + }, + **_default_disabled_quantizer_cfg, + }, + "algorithm": { + "method": "gptq_lite", + "use_sequential": False, + }, +} + +NVFP4_DYNAMIC_WO_CFG = { + "quant_cfg": { + "*weight_quantizer": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + }, + "*input_quantizer": { + "enable": False, + }, + **_default_disabled_quantizer_cfg, + }, + "algorithm": { + "method": "gptq_lite", + "use_sequential": False, + }, +} INT4_AWQ_CFG = { "quant_cfg": { From f88ba6ed35d27fda152c73aae40999a9b7c8ac8e Mon Sep 17 00:00:00 2001 From: Suguna Velury 
<178320438+sugunav14@users.noreply.github.com> Date: Tue, 24 Feb 2026 00:01:34 +0000 Subject: [PATCH 20/48] initial cleanup Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- examples/llm_ptq/hf_ptq.py | 53 ------------------ modelopt/torch/export/quant_utils.py | 62 ++++++---------------- modelopt/torch/export/unified_export_hf.py | 11 ++-- modelopt/torch/quantization/__init__.py | 1 - modelopt/torch/quantization/model_calib.py | 4 -- 5 files changed, 20 insertions(+), 111 deletions(-) diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index 18c0bb8d57..43adfcea7c 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -569,43 +569,6 @@ def mono_quantize( else: calibrate_loop = create_forward_loop(dataloader=calib_dataloader) - # Phase 1: Collect pre-quantization activations (batch_size=1 to save memory) - if getattr(args, "measure_activation_mse", False): - mse_max_samples = getattr(args, "activation_mse_max_samples", 16) - mse_save_dir = getattr(args, "activation_mse_save_dir", None) - mse_input_path = getattr(args, "activation_mse_input_path", None) - - # Materialize or load a frozen set of MSE inputs so that the exact - # same samples are used across runs and across codebases. 
- if mse_input_path and os.path.isfile(mse_input_path): - mse_data = mtq.ActivationMSELogger.load_data(mse_input_path) - else: - from torch.utils.data import DataLoader as _DataLoader - - mse_dataloader = _DataLoader(calib_dataloader.dataset, batch_size=1, shuffle=False) - if mse_input_path: - mse_data = mtq.ActivationMSELogger.materialize_data( - mse_dataloader, - mse_input_path, - max_samples=mse_max_samples, - ) - else: - # No path given -- materialize in memory only - mse_data = [] - for i, batch in enumerate(mse_dataloader): - if i >= mse_max_samples: - break - t = batch["input_ids"] if isinstance(batch, dict) else batch - mse_data.append(t.cpu()) - - mse_logger = mtq.ActivationMSELogger( - max_samples=mse_max_samples, - layer_filter=getattr(args, "activation_mse_layer_filter", None), - save_dir=mse_save_dir, - ) - print("\n--- Phase 1: Collecting pre-quantization activations ---") - mse_logger.collect(language_model, mse_data, phase="original") - if calibration_only: language_model = mtq.calibrate( language_model, quant_cfg["algorithm"], forward_loop=calibrate_loop @@ -613,16 +576,6 @@ def mono_quantize( else: language_model = mtq.quantize(language_model, quant_cfg, forward_loop=calibrate_loop) - # Phase 2: Compute MSE against stored pre-quant activations - if getattr(args, "measure_activation_mse", False): - print("\n--- Phase 2: Computing per-layer activation MSE ---") - mse_logger.collect(language_model, mse_data, phase="quantized") - mse_logger.compute_mse() - print(mse_logger.summary()) - if mse_save_dir: - mse_logger.save() - del mse_logger, mse_data - # For VL models, update full_model to use the quantized language model if is_nemotron_vl_model: language_model_lineage = get_language_model_from_vl(full_model) @@ -1192,12 +1145,6 @@ def parse_args() -> argparse.Namespace: default=False, action="store_true", ) - parser.add_argument( - "--export_qdq_weights", - help=("Used for GPTQ weights as is without compressed weights for deployment."), - 
default=False, - action="store_true", - ) parser.add_argument( "--verbose", help="Print verbose output (e.g. quantization summary). Disable by --no-verbose.", diff --git a/modelopt/torch/export/quant_utils.py b/modelopt/torch/export/quant_utils.py index b762757cb9..674d0596e3 100755 --- a/modelopt/torch/export/quant_utils.py +++ b/modelopt/torch/export/quant_utils.py @@ -46,7 +46,7 @@ ) from modelopt.torch.utils import clear_cuda_cache -from ..quantization.nn import NVFP4StaticQuantizer, SequentialQuantizer, TensorQuantizer +from ..quantization.nn import SequentialQuantizer, TensorQuantizer from .model_config import ( KV_CACHE_FP8, KV_CACHE_INT8, @@ -353,17 +353,15 @@ def get_weight_scaling_factor(module: nn.Module, weight_name: str = "weight") -> return get_scaling_factor(weight_quantizer[0]) quantization_format = get_quantization_format(module) - if quantization_format in [ QUANTIZATION_NVFP4, QUANTIZATION_NVFP4_AWQ, QUANTIZATION_NVFP4_SVDQUANT, QUANTIZATION_W4A8_NVFP4_FP8, ]: - # Calibrate weight quantizer if amax is not set (only needed for dynamic quantizers) - if not is_nvfp4_static: - module_name = f"{type(module).__name__}.{weight_name}" - _ensure_weight_quantizer_calibrated(weight_quantizer, weight, module_name) + # Calibrate weight quantizer if amax is not set + module_name = f"{type(module).__name__}.{weight_name}" + _ensure_weight_quantizer_calibrated(weight_quantizer, weight, module_name) if quantization_format == QUANTIZATION_W4A8_NVFP4_FP8: # weight_scaling_factor_2 for w4a8 needs to be amax/448, so that the wsf is in range 448/6. 
@@ -373,10 +371,9 @@ def get_weight_scaling_factor(module: nn.Module, weight_name: str = "weight") -> weight_scaling_factor_2 = NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer( weight_quantizer ) - # Unified method handles both static and dynamic quantizers - return NVFP4QTensor.get_weights_scaling_factor_from_quantizer( - weight_quantizer, + return NVFP4QTensor.get_weights_scaling_factor( weight, + weight_quantizer.block_sizes[-1], weight_scaling_factor_2.to(weight.device), )[0] @@ -410,13 +407,16 @@ def get_weight_scaling_factor_2(module: nn.Module, weight_name: str = "weight") module_name = f"{type(module).__name__}.{weight_name}" _ensure_weight_quantizer_calibrated(weight_quantizer, weight, module_name) - if quantization_format == QUANTIZATION_W4A8_NVFP4_FP8: - # weight_scaling_factor_2 for w4a8 needs to be amax/448, so that the wsf is in range 448/6. - # This is because the kernel dequantizes weight to fp8, which is in range 448. - return weight_quantizer._amax.float() / 448.0 - else: - # Unified method handles both static and dynamic quantizers - return NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(weight_quantizer) + if quantization_format in [ + QUANTIZATION_NVFP4, + QUANTIZATION_NVFP4_AWQ, + QUANTIZATION_NVFP4_SVDQUANT, + ]: + return NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(weight_quantizer) + elif quantization_format == QUANTIZATION_W4A8_NVFP4_FP8: + # weight_scaling_factor_2 for w4a8 needs to be amax/448, so that the wsf is in range 448/6. + # This is because the kernel dequantizes weight to fp8, which is in range 448. 
+ return weight_quantizer._amax.float() / 448.0 # SequentialQuantizer is required if not isinstance(weight_quantizer, SequentialQuantizer) or not weight_quantizer[-1].is_enabled: @@ -799,7 +799,7 @@ def process_layer_quant_config(layer_config_dict): layer_config = {"quant_algo": "W8A16"} elif v == "int8_sq": layer_config = {"quant_algo": "W8A8_SQ_PER_CHANNEL"} - elif v in ["nvfp4", "nvfp4_static"]: + elif v == "nvfp4": layer_config = { "quant_algo": "NVFP4", "group_size": block_size_value, @@ -1397,18 +1397,6 @@ def preprocess_linear_fusion(modules: list[torch.nn.Module], resmooth_only=False for module in modules: module.weight_quantizer[-1].amax = weight_amax - # Handle NVFP4StaticQuantizer: unify global_amax for fused layers - elif isinstance(modules[0].weight_quantizer, NVFP4StaticQuantizer): - global_amax_list = [ - m.weight_quantizer.global_amax - for m in modules - if m.weight_quantizer.global_amax is not None - ] - if global_amax_list: - unified_global_amax = torch.max(torch.stack(global_amax_list)) - for module in modules: - module.weight_quantizer.global_amax = unified_global_amax - elif ( modules[0].weight_quantizer.is_enabled and modules[0].weight_quantizer.amax is not None @@ -1493,22 +1481,6 @@ def get_quant_config( if block_size == 0: block_size = get_weight_block_size(module) - # Static NVFP4 uses pre-computed per-block scales from MSE calibration - if quantization_format == QUANTIZATION_NVFP4: - weight_quantizer = getattr(module, "weight_quantizer", None) - if weight_quantizer is None: - # Try to get from first weight attribute - for wn in weight_names: - weight_quantizer = getattr( - module, quantizer_attr_names(wn).weight_quantizer, None - ) - if weight_quantizer is not None: - break - if weight_quantizer is not None: - is_static = isinstance(weight_quantizer, NVFP4StaticQuantizer) - if is_static: - quantization_format = "nvfp4_static" - # Construct per layer config dictionary layer_config_dict[name + ".quantization"] = quantization_format 
layer_config_dict[name + ".awq_block_size"] = block_size diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index 92c196e151..ee230ef948 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -52,11 +52,7 @@ from torch.distributed.fsdp import FSDPModule from modelopt.torch.quantization import set_quantizer_by_cfg_context -from modelopt.torch.quantization.nn import ( - NVFP4StaticQuantizer, - SequentialQuantizer, - TensorQuantizer, -) +from modelopt.torch.quantization.nn import SequentialQuantizer, TensorQuantizer from modelopt.torch.quantization.qtensor import MXFP8QTensor, NVFP4QTensor from modelopt.torch.quantization.utils import fsdp2_aware_weight_update, quantizer_attr_names @@ -548,7 +544,6 @@ def _export_quantized_weight( weight, _ = maybe_transpose_expert_weight_dimensions( weight, is_bmm_expert_weight=is_bmm_expert_weight ) - weight_scale = NVFP4QTensor.get_weights_scaling_factor( weight, block_size=block_size, @@ -556,7 +551,7 @@ def _export_quantized_weight( )[0] quantized_weight = to_quantized_weight( - weight.to(torch.bfloat16), + weight.to(dtype), weight_scale, quantization_format, weight_scale_2, @@ -573,7 +568,7 @@ def _export_quantized_weight( ) quantized_weight = to_quantized_weight( - weight.to(torch.bfloat16), + weight.to(dtype), weight_scale, quantization_format, weight_scale_2, diff --git a/modelopt/torch/quantization/__init__.py b/modelopt/torch/quantization/__init__.py index 757b844fb1..87dbf30bb5 100644 --- a/modelopt/torch/quantization/__init__.py +++ b/modelopt/torch/quantization/__init__.py @@ -19,7 +19,6 @@ from . 
import mode, plugins, utils # Add methods to mtq namespace -from .activation_mse import ActivationMSELogger, collect_activations, measure_activation_mse from .compress import * from .config import * from .conversion import * diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 100c749234..7b390ef0f2 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -15,9 +15,6 @@ """Calibration utilities.""" -import contextlib -import datetime -import json import math import os import warnings @@ -1942,7 +1939,6 @@ def hessian_forward(self, input, *args, **kwargs): for module in patched_modules: unpatch_forward_method(module, "_forward_no_gptq_hessian") - # Phase 3: Update weights using computed Hessians (same as gptq_lite) print_rank_0("Updating weights using GPTQ algorithm...") for name, module in layer.named_modules(): From 7b24cd3980c08571c908a8adf5757e48e7069cd0 Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Tue, 24 Feb 2026 00:24:55 +0000 Subject: [PATCH 21/48] cleanup Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- modelopt/torch/quantization/activation_mse.py | 787 ------------------ modelopt/torch/quantization/config.py | 166 +--- 2 files changed, 1 insertion(+), 952 deletions(-) delete mode 100644 modelopt/torch/quantization/activation_mse.py diff --git a/modelopt/torch/quantization/activation_mse.py b/modelopt/torch/quantization/activation_mse.py deleted file mode 100644 index df90c84a3a..0000000000 --- a/modelopt/torch/quantization/activation_mse.py +++ /dev/null @@ -1,787 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Per-layer activation MSE measurement for quantization analysis. - -This module provides utilities to measure per-linear-layer MSE between a model's -activations before and after quantization. Inspired by FP-Quant's two-phase approach: - -- **Phase 1** (before quantization): ``collect_activations()`` runs the model on - calibration data and stores per-layer outputs in CPU RAM. -- **Phase 2** (after quantization): ``measure_activation_mse()`` runs the quantized - model on the same data and computes MSE on-the-fly against the stored Phase 1 - outputs. Only running scalar accumulators are kept -- no second set of tensors - is stored. 
- -Typical usage in hf_ptq.py:: - - # Phase 1: before quantization - orig_acts = mtq.collect_activations(model, mse_dataloader, max_samples=16) - - # Quantize - model = mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop) - - # Phase 2: after quantization -- computes MSE incrementally - mse = mtq.measure_activation_mse(model, mse_dataloader, orig_acts, max_samples=16) -""" - -import contextlib -import fnmatch -import hashlib -import os -from collections.abc import Iterable -from datetime import datetime - -import torch -import torch.nn as nn -import torch.nn.functional as F -from tqdm import tqdm - -from modelopt.torch.utils.network import get_decoder_layers - -__all__ = ["ActivationMSELogger", "collect_activations", "measure_activation_mse"] - - -def _tensor_from_output(out) -> torch.Tensor: - """Extract a single tensor from a layer's output (handles tuple returns).""" - if isinstance(out, torch.Tensor): - return out.detach() - return out[0].detach() - - -def _is_linear(module: nn.Module) -> bool: - """Check if a module is a linear layer (covers both nn.Linear and quantized linear).""" - return isinstance(module, nn.Linear) - - -def _matches_filter(name: str, layer_filter: str | None) -> bool: - """Check if a layer name matches the optional filter pattern (fnmatch-style).""" - if layer_filter is None: - return True - return fnmatch.fnmatch(name, layer_filter) - - -def _discover_target_layers( - model: nn.Module, - layer_filter: str | None = None, -) -> dict[str, nn.Module]: - """Discover linear layers within decoder blocks of the model. - - Uses get_decoder_layers() to find transformer blocks, then finds all linear - submodules within those blocks. Falls back to all linear layers in the model - if decoder blocks cannot be identified. - - Args: - model: The model to inspect. - layer_filter: Optional fnmatch pattern to select specific layers - (e.g., ``"*self_attn*"``). - - Returns: - Dict mapping full module path -> module reference. 
- """ - decoder_layers = get_decoder_layers(model) - - targets: dict[str, nn.Module] = {} - - if decoder_layers is not None: - # Build a reverse lookup: module id -> full name in model - module_to_name: dict[int, str] = {id(m): n for n, m in model.named_modules()} - - for block in decoder_layers: - block_name = module_to_name.get(id(block), "") - for sub_name, sub_mod in block.named_modules(): - if _is_linear(sub_mod): - full_name = f"{block_name}.{sub_name}" if block_name else sub_name - if _matches_filter(full_name, layer_filter): - targets[full_name] = sub_mod - else: - # Fallback: scan all modules - for name, module in model.named_modules(): - if _is_linear(module): - if _matches_filter(name, layer_filter): - targets[name] = module - - return targets - - -def _run_batch(model: nn.Module, batch) -> None: - """Run a single batch through the model.""" - if isinstance(batch, dict): - model(**batch) - elif isinstance(batch, (list, tuple)): - model(*batch) - else: - model(batch) - - -@torch.no_grad() -def collect_activations( - model: nn.Module, - dataloader: Iterable, - max_samples: int | None = None, - layer_filter: str | None = None, -) -> dict[str, list[torch.Tensor]]: - """Collect per-linear-layer output activations into CPU memory (Phase 1). - - Registers forward hooks on linear layers within the model's decoder blocks, - runs calibration data through the model, and returns captured per-layer outputs. - - Args: - model: The model to collect activations from (typically pre-quantization). - dataloader: An iterable yielding batches (dicts with ``input_ids``, etc.). - Use batch_size=1 to minimize memory. - max_samples: Maximum number of batches to process. ``None`` means all. - layer_filter: Optional fnmatch pattern to restrict which layers are - collected (e.g., ``"*self_attn*"``). ``None`` means all linear layers - inside decoder blocks. - - Returns: - Dict mapping layer name to a list of output tensors (one per batch, on CPU). 
- """ - was_training = model.training - model.eval() - - # Discover target linear layers - targets = _discover_target_layers(model, layer_filter) - if not targets: - raise ValueError( - f"No linear layers found matching the given filter. layer_filter={layer_filter!r}" - ) - - print(f"Collecting activations for {len(targets)} layers...") - - # Storage: {layer_name: [tensor_per_batch, ...]} - saved: dict[str, list[torch.Tensor]] = {name: [] for name in targets} - captured: dict[str, torch.Tensor] = {} - - def _make_hook(key: str): - def hook(_module: nn.Module, _input, output) -> None: - captured[key] = _tensor_from_output(output).cpu() - - return hook - - # Register hooks - hooks = [] - for name, module in targets.items(): - hooks.append(module.register_forward_hook(_make_hook(name))) - - try: - n_batches = 0 - for batch in tqdm(dataloader, desc="Collecting activations", leave=False): - if max_samples is not None and n_batches >= max_samples: - break - - captured.clear() - _run_batch(model, batch) - - for name in targets: - if name in captured: - saved[name].append(captured[name]) - - n_batches += 1 - finally: - for h in hooks: - h.remove() - - model.train(was_training) - - print(f"Collected {n_batches} samples across {len(targets)} layers") - return saved - - -@torch.no_grad() -def measure_activation_mse( - model: nn.Module, - dataloader: Iterable, - orig_activations: dict[str, list[torch.Tensor]], - max_samples: int | None = None, - layer_filter: str | None = None, -) -> dict[str, float]: - """Compute per-layer MSE between stored and live activations (Phase 2). - - Runs the (quantized) model on calibration data and computes MSE on-the-fly - against the pre-quantization activations stored by :func:`collect_activations`. - - Only scalar accumulators (sum of squared errors and element count) are kept - per layer -- no second set of activation tensors is stored. 
- - The MSE for each layer is computed as:: - - MSE = sum_over_all_elements((orig - quant) ^ 2) / total_elements - - Args: - model: The quantized model to measure. - dataloader: Same dataloader used for :func:`collect_activations` - (must yield batches in the same order). - orig_activations: Output of :func:`collect_activations` -- dict mapping - layer name to a list of pre-quantization output tensors. - max_samples: Maximum number of batches to process (should match Phase 1). - layer_filter: Optional fnmatch pattern (should match Phase 1). - - Returns: - Dict mapping layer name to its MSE value. - """ - was_training = model.training - model.eval() - - # Discover target layers on the (now-quantized) model - targets = _discover_target_layers(model, layer_filter) - - # Only measure layers that exist in both the model and orig_activations - common_keys = sorted(set(targets.keys()) & set(orig_activations.keys())) - if not common_keys: - raise ValueError( - "No matching layers between the quantized model and stored activations. " - "Ensure the same layer_filter is used for both phases." 
- ) - - skipped = set(orig_activations.keys()) - set(targets.keys()) - if skipped: - print(f"Warning: {len(skipped)} layers in orig_activations not found in model (skipped)") - - print(f"Computing activation MSE for {len(common_keys)} layers...") - - # Scalar accumulators - sum_sq: dict[str, float] = dict.fromkeys(common_keys, 0.0) - count: dict[str, int] = dict.fromkeys(common_keys, 0) - - captured: dict[str, torch.Tensor] = {} - - def _make_hook(key: str): - def hook(_module: nn.Module, _input, output) -> None: - captured[key] = _tensor_from_output(output).cpu() - - return hook - - # Register hooks only on common layers - hooks = [targets[name].register_forward_hook(_make_hook(name)) for name in common_keys] - - try: - batch_idx = 0 - for batch in tqdm(dataloader, desc="Computing activation MSE", leave=False): - if max_samples is not None and batch_idx >= max_samples: - break - - captured.clear() - _run_batch(model, batch) - - for name in common_keys: - if name not in captured: - continue - if batch_idx >= len(orig_activations.get(name, [])): - continue - - o = orig_activations[name][batch_idx].float() - q = captured[name].float() - - if o.shape != q.shape: - print( - f"Warning: shape mismatch for {name} batch {batch_idx}: " - f"{o.shape} vs {q.shape}, skipping" - ) - continue - - sum_sq[name] += F.mse_loss(o, q, reduction="sum").item() - count[name] += o.numel() - - batch_idx += 1 - finally: - for h in hooks: - h.remove() - - model.train(was_training) - - mse = { - key: (sum_sq[key] / count[key]) if count[key] > 0 else float("nan") for key in common_keys - } - - return mse - - -# --------------------------------------------------------------------------- -# Portable ActivationMSELogger class -# --------------------------------------------------------------------------- - - -def _portable_discover_target_layers( - model: nn.Module, - layer_filter: str | None = None, -) -> dict[str, nn.Module]: - """Discover linear layers in decoder blocks with a portable fallback 
chain. - - Strategy: - 1. Try modelopt's ``get_decoder_layers`` (available inside ModelOpt). - 2. Try common HuggingFace attribute paths (``model.model.layers``, etc.). - 3. Fall back to scanning **all** ``nn.Linear`` in ``model.named_modules()``. - - Within each set of decoder blocks the function collects every ``nn.Linear`` - sub-module and optionally filters by *layer_filter* (fnmatch pattern). - """ - decoder_layers = None - - # 1. Try modelopt helper (may not exist when file is copied elsewhere) - with contextlib.suppress(Exception): - decoder_layers = get_decoder_layers(model) - - # 2. Try common HF / other patterns - if decoder_layers is None: - for attr_chain in ( - ("model", "layers"), - ("decoder", "layers"), - ("transformer", "h"), - ("backbone", "layers"), - ): - obj = model - try: - for attr in attr_chain: - obj = getattr(obj, attr) - if isinstance(obj, nn.ModuleList): - decoder_layers = obj - break - except AttributeError: - continue - - targets: dict[str, nn.Module] = {} - - if decoder_layers is not None: - module_to_name: dict[int, str] = {id(m): n for n, m in model.named_modules()} - for block in decoder_layers: - block_name = module_to_name.get(id(block), "") - for sub_name, sub_mod in block.named_modules(): - if isinstance(sub_mod, nn.Linear): - full_name = f"{block_name}.{sub_name}" if block_name else sub_name - if _matches_filter(full_name, layer_filter): - targets[full_name] = sub_mod - else: - # 3. Fallback: all linear layers - for name, module in model.named_modules(): - if isinstance(module, nn.Linear): - if _matches_filter(name, layer_filter): - targets[name] = module - - return targets - - -class ActivationMSELogger: - """Portable activation MSE logger for comparing original vs quantized models. - - Works with both: - - - ``List[Tensor]`` data (**FP-Quant** style): each element is ``[1, seq_len]`` - or ``[B, seq_len]``, consumed via ``model(tensor)``. 
- - ``DataLoader`` / ``Iterable`` yielding dicts (**ModelOpt** style): - ``{"input_ids": tensor, ...}``, consumed via ``model(**batch)``. - - Guarantees same samples are used for both phases via SHA-256 hashing of - input tensors. Supports saving / loading all activations to disk for - later cross-codebase comparison. - - Example (ModelOpt -- DataLoader with dict batches):: - - mse_logger = ActivationMSELogger(max_samples=16, save_dir="./mse_logs") - mse_logger.collect(model, dataloader, phase="original") - model = mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop) - mse_logger.collect(model, dataloader, phase="quantized") - results = mse_logger.compute_mse() - print(mse_logger.summary()) - mse_logger.save() - - Example (FP-Quant -- List[Tensor]):: - - mse_logger = ActivationMSELogger(max_samples=16, save_dir="./mse_logs") - mse_logger.collect(model_orig, calibration_data, phase="original") - mse_logger.collect(model_quant, calibration_data, phase="quantized") - results = mse_logger.compute_mse() - print(mse_logger.summary()) - mse_logger.save() - """ - - def __init__( - self, - max_samples: int = 16, - layer_filter: str | None = None, - save_dir: str | None = None, - ): - """Initialize the ActivationMSELogger. - - Args: - max_samples: Maximum number of calibration batches to process per phase. - layer_filter: Optional glob pattern to restrict which layers are tracked. - save_dir: Optional directory path for persisting activation data to disk. 
- """ - self.max_samples = max_samples - self.layer_filter = layer_filter - self.save_dir = save_dir - - # Per-phase state - self.original_activations: dict[str, list[torch.Tensor]] = {} - self.quantized_activations: dict[str, list[torch.Tensor]] = {} - self.input_hashes: list[str] = [] # hashes for "original" phase - self.quant_input_hashes: list[str] = [] # hashes for "quantized" phase - - # Computed after both phases - self.mse_results: dict[str, float] | None = None - - # ------------------------------------------------------------------ - # Public API - # ------------------------------------------------------------------ - - @torch.no_grad() - def collect( - self, - model: nn.Module, - data: Iterable, - phase: str, - target_modules: dict[str, nn.Module] | None = None, - ) -> None: - """Collect per-linear-layer output activations for a given phase. - - Args: - model: The model to run (original or quantized). - data: An iterable of batches. Each batch can be: - - - ``torch.Tensor`` with shape ``[B, seq_len]`` (FP-Quant style). - - ``dict`` with at least an ``"input_ids"`` key (ModelOpt style). - - ``list`` / ``tuple`` of tensors. - phase: ``"original"`` or ``"quantized"``. - target_modules: Optional explicit mapping of ``{name: nn.Module}`` - to attach hooks to. If *None*, layers are auto-discovered - via decoder-block scanning. - """ - if phase not in ("original", "quantized"): - raise ValueError(f"phase must be 'original' or 'quantized', got {phase!r}") - - was_training = model.training - model.eval() - - # ----- layer discovery ----- - targets = ( - target_modules - if target_modules is not None - else (_portable_discover_target_layers(model, self.layer_filter)) - ) - if not targets: - raise ValueError( - "No linear layers found. Provide target_modules explicitly or " - f"check layer_filter={self.layer_filter!r}." 
- ) - - print( - f"[ActivationMSELogger] Phase '{phase}': hooking {len(targets)} layers, " - f"max_samples={self.max_samples}" - ) - - # ----- storage ----- - saved: dict[str, list[torch.Tensor]] = {name: [] for name in targets} - captured: dict[str, torch.Tensor] = {} - hashes: list[str] = [] - - def _make_hook(key: str): - def hook(_module: nn.Module, _input, output) -> None: - captured[key] = _tensor_from_output(output).cpu() - - return hook - - hooks = [] - for name, module in targets.items(): - hooks.append(module.register_forward_hook(_make_hook(name))) - - try: - n_batches = 0 - for batch in tqdm(data, desc=f"Collecting ({phase})", leave=False): - if self.max_samples is not None and n_batches >= self.max_samples: - break - - captured.clear() - self._run_batch(model, batch) - - for name in targets: - if name in captured: - saved[name].append(captured[name]) - - hashes.append(self._hash_batch(batch)) - n_batches += 1 - finally: - for h in hooks: - h.remove() - - model.train(was_training) - - # ----- store results on self ----- - if phase == "original": - self.original_activations = saved - self.input_hashes = hashes - else: - self.quantized_activations = saved - self.quant_input_hashes = hashes - # Verify sample consistency - if self.input_hashes: - self._verify_hashes() - - # Invalidate any previous MSE since we have new activations - self.mse_results = None - - print(f"[ActivationMSELogger] Collected {n_batches} batches for phase '{phase}'") - - def compute_mse(self) -> dict[str, float]: - """Compute per-layer MSE between original and quantized activations. - - Returns: - Dict mapping layer name to its MSE value. - - Raises: - ValueError: If either phase has not been collected yet. - """ - if not self.original_activations: - raise ValueError( - "No original activations collected. Call collect(..., phase='original') first." - ) - if not self.quantized_activations: - raise ValueError( - "No quantized activations collected. 
Call collect(..., phase='quantized') first." - ) - - common_keys = sorted( - set(self.original_activations.keys()) & set(self.quantized_activations.keys()) - ) - if not common_keys: - raise ValueError( - "No matching layer names between original and quantized activations. " - "Ensure the same model architecture / layer_filter is used for both phases." - ) - - orig_only = set(self.original_activations.keys()) - set(self.quantized_activations.keys()) - quant_only = set(self.quantized_activations.keys()) - set(self.original_activations.keys()) - if orig_only: - print( - f"[ActivationMSELogger] Warning: {len(orig_only)} layers only in original (skipped)" - ) - if quant_only: - print( - f"[ActivationMSELogger] Warning: {len(quant_only)} layers only in quantized (skipped)" - ) - - sum_sq: dict[str, float] = dict.fromkeys(common_keys, 0.0) - count: dict[str, int] = dict.fromkeys(common_keys, 0) - - for name in common_keys: - orig_list = self.original_activations[name] - quant_list = self.quantized_activations[name] - n = min(len(orig_list), len(quant_list)) - for i in range(n): - o = orig_list[i].float() - q = quant_list[i].float() - if o.shape != q.shape: - print( - f"[ActivationMSELogger] Warning: shape mismatch for {name} " - f"batch {i}: {o.shape} vs {q.shape}, skipping" - ) - continue - sum_sq[name] += F.mse_loss(o, q, reduction="sum").item() - count[name] += o.numel() - - self.mse_results = { - key: (sum_sq[key] / count[key]) if count[key] > 0 else float("nan") - for key in common_keys - } - return self.mse_results - - def save(self, path: str | None = None) -> str: - """Save all state (activations, hashes, MSE) to disk via ``torch.save``. - - Args: - path: Explicit file path. If *None*, a timestamped file is created - inside ``self.save_dir`` (which must be set). - - Returns: - The path where the file was saved. 
- """ - if path is None: - if self.save_dir is None: - raise ValueError("Provide a path or set save_dir in the constructor.") - os.makedirs(self.save_dir, exist_ok=True) - ts = datetime.now().strftime("%Y%m%d_%H%M%S") - path = os.path.join(self.save_dir, f"activation_mse_{ts}.pt") - - payload = { - "max_samples": self.max_samples, - "layer_filter": self.layer_filter, - "input_hashes": self.input_hashes, - "quant_input_hashes": self.quant_input_hashes, - "original_activations": self.original_activations, - "quantized_activations": self.quantized_activations, - "mse": self.mse_results, - } - torch.save(payload, path) - print(f"[ActivationMSELogger] Saved to {path}") - return path - - @classmethod - def load(cls, path: str) -> "ActivationMSELogger": - """Load a previously saved ``ActivationMSELogger`` from disk. - - Args: - path: Path to the ``.pt`` file created by :meth:`save`. - - Returns: - A new ``ActivationMSELogger`` instance with restored state. - """ - payload = torch.load(path, map_location="cpu", weights_only=False) - logger = cls( - max_samples=payload.get("max_samples", 16), - layer_filter=payload.get("layer_filter"), - ) - logger.original_activations = payload.get("original_activations", {}) - logger.quantized_activations = payload.get("quantized_activations", {}) - logger.input_hashes = payload.get("input_hashes", []) - logger.quant_input_hashes = payload.get("quant_input_hashes", []) - logger.mse_results = payload.get("mse") - print(f"[ActivationMSELogger] Loaded from {path}") - return logger - - def summary(self) -> str: - """Return a formatted string summarising per-layer MSE results. - - Computes MSE first if not already done. 
- """ - if self.mse_results is None: - self.compute_mse() - assert self.mse_results is not None - - lines = ["Per-layer activation MSE (original vs quantized):"] - lines.extend( - f" {key}: {self.mse_results[key]:.6e}" for key in sorted(self.mse_results.keys()) - ) - return "\n".join(lines) - - # ------------------------------------------------------------------ - # Pre-materialized MSE data (cross-run / cross-codebase safety) - # ------------------------------------------------------------------ - - @staticmethod - def materialize_data( - data: Iterable, - path: str, - max_samples: int | None = None, - ) -> list[torch.Tensor]: - """Freeze the first *max_samples* batches from *data* into a ``.pt`` file. - - Each batch (``dict``, ``Tensor``, or ``list/tuple``) is normalised to a - single ``input_ids`` CPU tensor before saving. The resulting file is a - plain ``List[Tensor]`` that can be loaded in **any** codebase and passed - straight to :meth:`collect`. - - If *path* already exists it is **not** overwritten -- call - :meth:`load_data` instead. - - Args: - data: Iterable of batches (DataLoader, List[Tensor], etc.). - path: Destination ``.pt`` file path. - max_samples: How many batches to keep. ``None`` means all. - - Returns: - The materialised list of CPU tensors (same object that was saved). 
- """ - samples: list[torch.Tensor] = [] - for batch in data: - if max_samples is not None and len(samples) >= max_samples: - break - if isinstance(batch, dict): - t = batch.get("input_ids", next(iter(batch.values()))) - elif isinstance(batch, torch.Tensor): - t = batch - elif isinstance(batch, (list, tuple)): - t = batch[0] - else: - raise TypeError(f"Unsupported batch type: {type(batch)}") - samples.append(t.cpu()) - - os.makedirs(os.path.dirname(path) or ".", exist_ok=True) - torch.save(samples, path) - print(f"[ActivationMSELogger] Materialised {len(samples)} MSE input samples -> {path}") - return samples - - @staticmethod - def load_data(path: str) -> list[torch.Tensor]: - """Load a previously materialised MSE input set. - - Args: - path: Path to the ``.pt`` file created by :meth:`materialize_data`. - - Returns: - ``List[Tensor]`` of input batches (on CPU). - """ - samples = torch.load(path, map_location="cpu", weights_only=True) - print(f"[ActivationMSELogger] Loaded {len(samples)} MSE input samples from {path}") - return samples - - # ------------------------------------------------------------------ - # Static / private helpers - # ------------------------------------------------------------------ - - @staticmethod - def _run_batch(model: nn.Module, batch) -> None: - """Run a single batch through the model (handles Tensor, dict, list/tuple). - - Automatically moves inputs to the model's device so that CPU-stored - materialized data works transparently with a CUDA model. 
- """ - device = next(model.parameters()).device - if isinstance(batch, dict): - batch = { - k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items() - } - model(**batch) - elif isinstance(batch, torch.Tensor): - model(batch.to(device)) - elif isinstance(batch, (list, tuple)): - batch = tuple(t.to(device) if isinstance(t, torch.Tensor) else t for t in batch) - model(*batch) - else: - raise TypeError(f"Unsupported batch type: {type(batch)}") - - @staticmethod - def _hash_batch(batch) -> str: - """Compute SHA-256 hash of the primary input tensor in *batch*. - - - ``dict`` -> hashes ``batch["input_ids"]`` (falls back to first value). - - ``Tensor`` -> hashes the tensor directly. - - ``list/tuple`` -> hashes the first element. - """ - if isinstance(batch, dict): - t = batch.get("input_ids", next(iter(batch.values()))) - elif isinstance(batch, torch.Tensor): - t = batch - elif isinstance(batch, (list, tuple)): - t = batch[0] if batch else None - else: - return "" - - if t is None or not isinstance(t, torch.Tensor): - return "" - return hashlib.sha256(t.cpu().contiguous().numpy().tobytes()).hexdigest() - - def _verify_hashes(self) -> None: - """Compare input hashes between original and quantized phases.""" - n = min(len(self.input_hashes), len(self.quant_input_hashes)) - mismatches = sum(1 for i in range(n) if self.input_hashes[i] != self.quant_input_hashes[i]) - if mismatches: - print( - f"[ActivationMSELogger] WARNING: {mismatches}/{n} batches have " - f"different input hashes between original and quantized phases. " - f"The same data may not have been used for both phases!" 
- ) - else: - print(f"[ActivationMSELogger] Input hash verification passed ({n}/{n} match)") diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index 777406d18e..04f50e0767 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -156,22 +156,12 @@ "*mlp.gate.*": {"enable": False}, # Skip the MOE router "*mlp.shared_expert_gate.*": {"enable": False}, # Skip the MOE router "*linear_attn.conv1d*": {"enable": False}, - "*mixer.conv1d*": {"enable": False}, + "*mixer.conv1d*": {"enable": False}, # Skip mamba conv1d "*output_layer*": {"enable": False}, "output.*": {"enable": False}, "default": {"enable": False}, } -super_disabled_quantizer_cfg = { - "*fc1_latent_proj*": {"enable": False}, # Skip Latent MOE - "*fc2_latent_proj*": {"enable": False}, # Skip Latent MOE - "*q_proj*": {"enable": False}, # Skip QKV Linear - "*k_proj*": {"enable": False}, # Skip QKV Linear - "*v_proj*": {"enable": False}, # Skip QKV Linear - "*o_proj*": {"enable": False}, # Skip Output Linear - "*mtp*": {"enable": False}, # Skip MTP layers -} - _mamba_moe_disabled_quantizer_cfg = { "*fc1_latent_proj*": {"enable": False}, # Skip Latent MOE @@ -182,53 +172,6 @@ "*o_proj*": {"enable": False}, # Skip QKV Output Projection } -SUPER_NVFP4_CONSERVATIVE_CFG = { - "quant_cfg": { - "*weight_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, - "axis": None, - "enable": True, - }, - "*input_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, - "axis": None, - "enable": True, - }, - **_default_disabled_quantizer_cfg, - **super_disabled_quantizer_cfg, - "*mixer.in_proj*": {"enable": False}, # Skip mamba linear - "*mixer.out_proj*": {"enable": False}, # Skip mamba linear - }, - "algorithm": "max", -} - -SUPER_NVFP4_CONSERVATIVE_GPTQ_CFG = { - "quant_cfg": { - "*weight_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 
16, "type": "static", "scale_bits": (4, 3)}, - "axis": None, - "enable": True, - }, - "*input_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, - "axis": None, - "enable": True, - }, - **_default_disabled_quantizer_cfg, - **super_disabled_quantizer_cfg, - "*mixer.in_proj*": {"enable": False}, # Skip mamba linear - "*mixer.out_proj*": {"enable": False}, # Skip mamba linear - }, - "algorithm": { - "method": "gptq", - "use_sequential": True, - }, -} - INT8_DEFAULT_CFG = { "quant_cfg": { @@ -328,113 +271,6 @@ "algorithm": "max", } -INT4_BLOCKWISE_WEIGHT_ONLY_GPTQ_CFG = { - "quant_cfg": { - "*weight_quantizer": {"num_bits": 4, "block_sizes": {-1: 128}, "enable": True}, - "*input_quantizer": {"enable": False}, - **_default_disabled_quantizer_cfg, - }, - "algorithm": { - "method": "gptq", - "use_sequential": True, - }, -} - -NVFP4_STATIC_WO_GPTQ_CFG = { - "quant_cfg": { - "*weight_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, - "axis": None, - "enable": True, - }, - "*input_quantizer": { - "enable": False, - }, - **_default_disabled_quantizer_cfg, - }, - "algorithm": { - "method": "gptq", - "use_sequential": True, - }, -} - -NVFP4_STATIC_WO_GPTQ_LITE_CFG = { - "quant_cfg": { - "*weight_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, - "axis": None, - "enable": True, - }, - "*input_quantizer": { - "enable": False, - }, - **_default_disabled_quantizer_cfg, - }, - "algorithm": { - "method": "gptq_lite", - "use_sequential": False, - }, -} - -NVFP4_STATIC_WO_CFG = { - "quant_cfg": { - "*weight_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, - "axis": None, - "enable": True, - }, - "*input_quantizer": { - "enable": False, - }, - **_default_disabled_quantizer_cfg, - }, - "algorithm": { - "method": "max", - "use_sequential": False, - }, -} - 
-NVFP4_STATIC_WO_GPTQ_LITE_CFG = { - "quant_cfg": { - "*weight_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, - "axis": None, - "enable": True, - }, - "*input_quantizer": { - "enable": False, - }, - **_default_disabled_quantizer_cfg, - }, - "algorithm": { - "method": "gptq_lite", - "use_sequential": False, - }, -} - -NVFP4_DYNAMIC_WO_CFG = { - "quant_cfg": { - "*weight_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, - "axis": None, - "enable": True, - }, - "*input_quantizer": { - "enable": False, - }, - **_default_disabled_quantizer_cfg, - }, - "algorithm": { - "method": "gptq_lite", - "use_sequential": False, - }, -} - INT4_AWQ_CFG = { "quant_cfg": { "*weight_quantizer": { From b17b9176f4549a3478a119d884d0eab1c798a20b Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Tue, 24 Feb 2026 00:31:18 +0000 Subject: [PATCH 22/48] removed stray prints Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- modelopt/torch/quantization/model_calib.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 7b390ef0f2..e45e7c3bd1 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -359,7 +359,6 @@ def mse_calibrate( if fp8_scale_sweep and is_nvfp4_static: # Replace calibrator with NVFP4MSECalibrator - print("mse_calibrate: Replacing calibrator with NVFP4MSECalibrator") module._calibrator = NVFP4MSECalibrator( amax=initial_amax, axis=module._calibrator._axis, @@ -646,7 +645,6 @@ def quant_func(x, amax, quantizer=weight_quantizer): error_func = helper.get_error_func() if fp8_scale_sweep and is_nvfp4_static: - print("local_hessian_calibrate: Replacing calibrator with NVFP4MSECalibrator") weight_quantizer._calibrator = NVFP4MSECalibrator( amax=initial_amax, 
axis=weight_quantizer._calibrator._axis if weight_quantizer._calibrator else None, From 8ff897697f595b74f77f3406cbeaa7a4b202e58e Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Fri, 6 Mar 2026 01:04:14 +0000 Subject: [PATCH 23/48] fix rebase issues Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- modelopt/torch/export/quant_utils.py | 55 ++++++++++++++----- modelopt/torch/export/unified_export_hf.py | 7 ++- modelopt/torch/quantization/config.py | 3 - modelopt/torch/quantization/mode.py | 2 - modelopt/torch/quantization/model_calib.py | 22 -------- .../nn/modules/tensor_quantizer.py | 13 +---- .../torch/quantization/triton/__init__.py | 4 -- 7 files changed, 49 insertions(+), 57 deletions(-) diff --git a/modelopt/torch/export/quant_utils.py b/modelopt/torch/export/quant_utils.py index 674d0596e3..4ceb51cd2c 100755 --- a/modelopt/torch/export/quant_utils.py +++ b/modelopt/torch/export/quant_utils.py @@ -46,7 +46,7 @@ ) from modelopt.torch.utils import clear_cuda_cache -from ..quantization.nn import SequentialQuantizer, TensorQuantizer +from ..quantization.nn import NVFP4StaticQuantizer, SequentialQuantizer, TensorQuantizer from .model_config import ( KV_CACHE_FP8, KV_CACHE_INT8, @@ -353,6 +353,7 @@ def get_weight_scaling_factor(module: nn.Module, weight_name: str = "weight") -> return get_scaling_factor(weight_quantizer[0]) quantization_format = get_quantization_format(module) + if quantization_format in [ QUANTIZATION_NVFP4, QUANTIZATION_NVFP4_AWQ, @@ -371,9 +372,10 @@ def get_weight_scaling_factor(module: nn.Module, weight_name: str = "weight") -> weight_scaling_factor_2 = NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer( weight_quantizer ) - return NVFP4QTensor.get_weights_scaling_factor( + # Unified method handles both static and dynamic quantizers + return NVFP4QTensor.get_weights_scaling_factor_from_quantizer( + weight_quantizer, weight, - weight_quantizer.block_sizes[-1], 
weight_scaling_factor_2.to(weight.device), )[0] @@ -407,16 +409,13 @@ def get_weight_scaling_factor_2(module: nn.Module, weight_name: str = "weight") module_name = f"{type(module).__name__}.{weight_name}" _ensure_weight_quantizer_calibrated(weight_quantizer, weight, module_name) - if quantization_format in [ - QUANTIZATION_NVFP4, - QUANTIZATION_NVFP4_AWQ, - QUANTIZATION_NVFP4_SVDQUANT, - ]: - return NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(weight_quantizer) - elif quantization_format == QUANTIZATION_W4A8_NVFP4_FP8: - # weight_scaling_factor_2 for w4a8 needs to be amax/448, so that the wsf is in range 448/6. - # This is because the kernel dequantizes weight to fp8, which is in range 448. - return weight_quantizer._amax.float() / 448.0 + if quantization_format == QUANTIZATION_W4A8_NVFP4_FP8: + # weight_scaling_factor_2 for w4a8 needs to be amax/448, so that the wsf is in range 448/6. + # This is because the kernel dequantizes weight to fp8, which is in range 448. + return weight_quantizer._amax.float() / 448.0 + else: + # Unified method handles both static and dynamic quantizers + return NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(weight_quantizer) # SequentialQuantizer is required if not isinstance(weight_quantizer, SequentialQuantizer) or not weight_quantizer[-1].is_enabled: @@ -799,7 +798,7 @@ def process_layer_quant_config(layer_config_dict): layer_config = {"quant_algo": "W8A16"} elif v == "int8_sq": layer_config = {"quant_algo": "W8A8_SQ_PER_CHANNEL"} - elif v == "nvfp4": + elif v in ["nvfp4", "nvfp4_static"]: layer_config = { "quant_algo": "NVFP4", "group_size": block_size_value, @@ -1397,6 +1396,18 @@ def preprocess_linear_fusion(modules: list[torch.nn.Module], resmooth_only=False for module in modules: module.weight_quantizer[-1].amax = weight_amax + # Handle NVFP4StaticQuantizer: unify global_amax for fused layers + elif isinstance(modules[0].weight_quantizer, NVFP4StaticQuantizer): + global_amax_list = [ + 
m.weight_quantizer.global_amax + for m in modules + if m.weight_quantizer.global_amax is not None + ] + if global_amax_list: + unified_global_amax = torch.max(torch.stack(global_amax_list)) + for module in modules: + module.weight_quantizer.global_amax = unified_global_amax + elif ( modules[0].weight_quantizer.is_enabled and modules[0].weight_quantizer.amax is not None @@ -1481,6 +1492,22 @@ def get_quant_config( if block_size == 0: block_size = get_weight_block_size(module) + # Static NVFP4 uses pre-computed per-block scales from MSE calibration + if quantization_format == QUANTIZATION_NVFP4: + weight_quantizer = getattr(module, "weight_quantizer", None) + if weight_quantizer is None: + # Try to get from first weight attribute + for wn in weight_names: + weight_quantizer = getattr( + module, quantizer_attr_names(wn).weight_quantizer, None + ) + if weight_quantizer is not None: + break + if weight_quantizer is not None: + is_static = isinstance(weight_quantizer, NVFP4StaticQuantizer) + if is_static: + quantization_format = "nvfp4_static" + # Construct per layer config dictionary layer_config_dict[name + ".quantization"] = quantization_format layer_config_dict[name + ".awq_block_size"] = block_size diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index ee230ef948..14a12bcdf3 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -52,7 +52,11 @@ from torch.distributed.fsdp import FSDPModule from modelopt.torch.quantization import set_quantizer_by_cfg_context -from modelopt.torch.quantization.nn import SequentialQuantizer, TensorQuantizer +from modelopt.torch.quantization.nn import ( + NVFP4StaticQuantizer, + SequentialQuantizer, + TensorQuantizer, +) from modelopt.torch.quantization.qtensor import MXFP8QTensor, NVFP4QTensor from modelopt.torch.quantization.utils import fsdp2_aware_weight_update, quantizer_attr_names @@ -544,6 +548,7 @@ def 
_export_quantized_weight( weight, _ = maybe_transpose_expert_weight_dimensions( weight, is_bmm_expert_weight=is_bmm_expert_weight ) + weight_scale = NVFP4QTensor.get_weights_scaling_factor( weight, block_size=block_size, diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index 04f50e0767..1643c42b36 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -454,9 +454,6 @@ def _nvfp4_selective_quant_cfg( }, "*input_quantizer": _nvfp4_quantizer, **_default_disabled_quantizer_cfg, - **_mamba_moe_disabled_quantizer_cfg, - "*mixer.in_proj*": {"enable": False}, # Skip mamba linear - "*mixer.out_proj*": {"enable": False}, # Skip mamba linear }, "algorithm": { "method": "mse", diff --git a/modelopt/torch/quantization/mode.py b/modelopt/torch/quantization/mode.py index efc66ffa94..88e93bb770 100644 --- a/modelopt/torch/quantization/mode.py +++ b/modelopt/torch/quantization/mode.py @@ -255,8 +255,6 @@ def wrapped_calib_func( else: # Direct calibration (existing behavior) func(model, forward_loop=forward_loop, **kwargs) - else: - raise ValueError(f"No calibration function provided for method: {method}") # Lets get the latest metadata for the quantizer states metadata = {} diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index e45e7c3bd1..c3e1c993bd 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -132,28 +132,6 @@ def max_calibrate( for name, module in model.named_modules(): if hasattr(module, "layer_sync_moe_local_experts_amax"): module.layer_sync_moe_local_experts_amax() - elif hasattr(module, "sync_moe_local_experts_amax"): - module.sync_moe_local_experts_amax() - - for name, module in list(model.named_modules()): - if isinstance(module, TensorQuantizer) and not module._disabled: - if module._calibrator is not None and not module._dynamic and hasattr(module, "_amax"): - # Get the 
initial amax from max calibration - initial_amax = module._amax.clone().detach() - - is_nvfp4_static = ( - module.is_static_block_quant - and module._num_bits == (2, 1) - and module._block_sizes is not None - and module._block_sizes.get("scale_bits") == (4, 3) - ) - - if is_nvfp4_static: - # Compute and set global_amax - global_amax = reduce_amax(initial_amax, axis=None) - - # Convert to NVFP4StaticQuantizer in-place - NVFP4StaticQuantizer.from_tensor_quantizer(module, global_amax=global_amax) if not distributed_sync: return diff --git a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py index 4317c58609..ec2c3cfc55 100644 --- a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py +++ b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py @@ -1331,19 +1331,10 @@ def global_amax(self, value): def _fake_quantize(self, inputs): """Fake quantization using two-level scaling with _amax and _global_amax.""" if self.amax is not None: - # Ensure amax/global_amax are on the same device as inputs. - # After from_pretrained with device_map, quantizer buffers may remain - # on CPU while model weights/activations are on GPU. 
- amax = self.amax - if amax.device != inputs.device: - amax = amax.to(inputs.device) - global_amax = self.global_amax - if global_amax is not None and global_amax.device != inputs.device: - global_amax = global_amax.to(inputs.device) return static_blockwise_fp4_fake_quant( inputs, - amax, - global_amax, # Can be None, will be computed internally + self.amax, + self.global_amax, # Can be None, will be computed internally True, # quantize_block_scales inputs.dtype, self._pass_through_bwd, diff --git a/modelopt/torch/quantization/triton/__init__.py b/modelopt/torch/quantization/triton/__init__.py index 6e8d4dba11..def70e5914 100644 --- a/modelopt/torch/quantization/triton/__init__.py +++ b/modelopt/torch/quantization/triton/__init__.py @@ -34,10 +34,6 @@ from .fp4_kernel import * from .fp8_kernel import * - # fp4_kernel_hopper requires compute >= 8.9 (uses tl.float8e4nv) - if torch.cuda.get_device_capability() >= (8, 9): - from .fp4_kernel_hopper import * - # fp4_kernel_hopper requires compute >= 8.9 (uses tl.float8e4nv) if torch.cuda.get_device_capability() >= (8, 9): from .fp4_kernel_hopper import * From 5815ce86ece3ce0808938e2b5c97e51f4591f318 Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Fri, 6 Mar 2026 01:06:18 +0000 Subject: [PATCH 24/48] minor Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- modelopt/torch/quantization/config.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index 1643c42b36..bde4d4525a 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -162,7 +162,6 @@ "default": {"enable": False}, } - _mamba_moe_disabled_quantizer_cfg = { "*fc1_latent_proj*": {"enable": False}, # Skip Latent MOE "*fc2_latent_proj*": {"enable": False}, # Skip Latent MOE @@ -172,7 +171,6 @@ "*o_proj*": {"enable": False}, # Skip QKV Output Projection } 
- INT8_DEFAULT_CFG = { "quant_cfg": { "*weight_quantizer": {"num_bits": 8, "axis": 0}, @@ -271,6 +269,7 @@ "algorithm": "max", } + INT4_AWQ_CFG = { "quant_cfg": { "*weight_quantizer": { From b1f1434d874fe38786a3f0fb0b57dbdb407bb6d0 Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Fri, 6 Mar 2026 22:12:00 +0000 Subject: [PATCH 25/48] tested e2e on qwen Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- examples/llm_ptq/hf_ptq.py | 34 ++++++++++ modelopt/torch/quantization/config.py | 20 ++++++ modelopt/torch/quantization/mode.py | 4 +- modelopt/torch/quantization/model_calib.py | 73 +++++++++++++++++++--- 4 files changed, 120 insertions(+), 11 deletions(-) diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index 43adfcea7c..8c1274dc88 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -24,6 +24,7 @@ import numpy as np import torch from accelerate.hooks import remove_hook_from_module +from eval_perplexity import evaluate_perplexity from example_utils import ( build_quant_cfg, copy_custom_model_files, @@ -108,6 +109,7 @@ def _set_kv_cache_constant_amax(quant_cfg: dict) -> None: "nvfp4_experts_only": mtq.NVFP4_EXPERTS_ONLY_CFG, "nvfp4_omlp_only": mtq.NVFP4_OMLP_ONLY_CFG, "nvfp4_svdquant": mtq.NVFP4_SVDQUANT_DEFAULT_CFG, + "nvfp4_gptq": mtq.NVFP4_GPTQ_CFG, "mxfp8": mtq.MXFP8_DEFAULT_CFG, "nvfp4_local_hessian": mtq.NVFP4_W4A4_WEIGHT_LOCAL_HESSIAN_CFG, } @@ -925,6 +927,7 @@ def quantize_main( else: # mono quantization +<<<<<<< HEAD if args.recipe is not None: print(f"Use recipe {args.recipe} for quantization") recipe = load_recipe(args.recipe) @@ -932,6 +935,26 @@ def quantize_main( f"Expected PTQ recipe, but got {type(recipe).__name__} from {args.recipe}" ) quant_cfg = recipe.ptq_cfg +======= + assert ( + args.qformat + in [ + "int8_wo", + "int4_awq", + "fp8", + "nvfp4", + "nvfp4_awq", + "nvfp4_mse", + "nvfp4_gptq", + "w4a8_awq", + "fp8_pb_wo", + 
"w4a8_mxfp4_fp8", + "nvfp4_mlp_only", + "mxfp8", + ] + or args.kv_cache_qformat in KV_QUANT_CFG_CHOICES + ), f"Plain quantization format {args.qformat} not supported for HF export path" +>>>>>>> 6b8812d6 (tested e2e on qwen) else: assert len(args.qformat.split(",")) == 1, ( @@ -1003,6 +1026,11 @@ def quantize_main( is_nemotron_vl_model, first_text_speech_dataset, ) + + if args.eval_perplexity and tokenizer is not None: + print("Evaluating Wikitext-2 perplexity...") + evaluate_perplexity(language_model, tokenizer, seq_len=args.calib_seq) + export_quantized( args, full_model, @@ -1161,6 +1189,12 @@ def parse_args() -> argparse.Namespace: default=False, action="store_true", ) + parser.add_argument( + "--eval_perplexity", + help="Evaluate Wikitext-2 perplexity after quantization.", + default=False, + action="store_true", + ) parser.add_argument( "--low_memory_mode", help=( diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index bde4d4525a..2b96e35e03 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -476,6 +476,25 @@ def _nvfp4_selective_quant_cfg( }, } +NVFP4_GPTQ_CFG = { + "quant_cfg": { + "*weight_quantizer": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + }, + "*input_quantizer": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + }, + **_default_disabled_quantizer_cfg, + }, + "algorithm": {"method": "gptq", "use_sequential": True}, +} + MAMBA_MOE_NVFP4_AGGRESSIVE_CFG = { "quant_cfg": { "*weight_quantizer": _nvfp4_quantizer, @@ -632,6 +651,7 @@ def _nvfp4_selective_quant_cfg( "NVFP4_AWQ_FULL_CFG", "NVFP4_AWQ_LITE_CFG", "NVFP4_DEFAULT_CFG", + "NVFP4_GPTQ_CFG", "NVFP4_FP8_MHA_CONFIG", "NVFP4_KV_CFG", "NVFP4_KV_ROTATE_CFG", diff --git a/modelopt/torch/quantization/mode.py b/modelopt/torch/quantization/mode.py index 
88e93bb770..df48c72c29 100644 --- a/modelopt/torch/quantization/mode.py +++ b/modelopt/torch/quantization/mode.py @@ -242,8 +242,8 @@ def wrapped_calib_func( if sequential: if forward_loop is None: raise ValueError("forward_loop is required for calibration but got None.") - assert method in ["max"], ( - f"Sequential calibration currently only supports max calibration, got {method}" + assert method in ["max", "gptq"], ( + f"Sequential calibration currently only supports max and gptq calibration, got {method}" ) # Wrap with sequential processing sequential_calibrate( diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index c3e1c993bd..d27151dcbb 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -1848,19 +1848,62 @@ def _layer_forward_loop(m, _inputs=layer_inputs): print_rank_0("Sequential calibration completed") +def _promote_nvfp4_static_quantizers(model: nn.Module) -> int: + """Convert eligible TensorQuantizers to NVFP4StaticQuantizer in-place. + + After max calibration sets per-block amax values, NVFP4 static quantizers + need to be promoted so they use the two-level scaling path (global amax + + per-block amax) instead of the generic E4M3 path. + + Returns the number of quantizers converted. 
+ """ + converted = 0 + for _name, module in list(model.named_modules()): + if isinstance(module, TensorQuantizer) and not module._disabled: + if module._calibrator is not None and not module._dynamic and hasattr(module, "_amax"): + is_nvfp4_static = ( + module.is_static_block_quant + and module._num_bits == (2, 1) + and module._block_sizes is not None + and module._block_sizes.get("scale_bits") == (4, 3) + ) + if is_nvfp4_static: + initial_amax = module._amax.clone().detach() + global_amax = reduce_amax(initial_amax, axis=None) + NVFP4StaticQuantizer.from_tensor_quantizer(module, global_amax=global_amax) + converted += 1 + return converted + + @torch.no_grad() def gptq( layer: nn.Module, - inputs: list[tuple[tuple, dict]], forward_loop: ForwardLoop, percdamp: float = 0.01, block_size: int = 128, **kwargs, ): - """GPTQ quantization - a GPTQ variant.""" - # Set weight amax and activation amax'es for the current layer using max_calibrate + """GPTQ quantization - a GPTQ variant. + + Args: + layer: A single decoder layer to quantize. + forward_loop: Callable that replays calibration inputs through the layer. + Provided by ``sequential_calibrate`` which captures per-layer activations. + percdamp: Percentage of avg Hessian diagonal for damping (default: 0.01). + block_size: Block size for GPTQ weight update. 
+ """ + import time + + total_start = time.time() + + # Set weight amax and activation amax for the current layer using max_calibrate max_calibrate(layer, forward_loop=forward_loop) + # Promote NVFP4 static quantizers so they use the two-level scaling path + n_promoted = _promote_nvfp4_static_quantizers(layer) + if n_promoted: + print_rank_0(f"Promoted {n_promoted} quantizer(s) to NVFP4StaticQuantizer") + # Dictionary to store hessian matrices for all linear layers in this decoder hessian_state = {} @@ -1904,18 +1947,20 @@ def hessian_forward(self, input, *args, **kwargs): bind_forward_method(module, _make_hessian_forward(name), "_forward_no_gptq_hessian") patched_modules.append(module) - # Run forward passes with the provided inputs to collect Hessians - print_rank_0( - f"Computing Hessians for {len(tensor_mapping)} linear layers using {len(inputs)} batches..." - ) - for args, kwargs_input in inputs: - layer(*args, **kwargs_input) + # Run forward passes to collect Hessians + hessian_start = time.time() + print_rank_0(f"Computing Hessians for {len(tensor_mapping)} linear layers...") + forward_loop(layer) # Unpatch forwards for module in patched_modules: unpatch_forward_method(module, "_forward_no_gptq_hessian") + torch.cuda.synchronize() if torch.cuda.is_available() else None + hessian_time = time.time() - hessian_start + # Phase 3: Update weights using computed Hessians (same as gptq_lite) + weight_update_start = time.time() print_rank_0("Updating weights using GPTQ algorithm...") for name, module in layer.named_modules(): if is_quantized_linear(module) and module.weight_quantizer.is_enabled: @@ -1927,3 +1972,13 @@ def hessian_forward(self, input, *args, **kwargs): # Free memory del hessian_state[module.name] torch.cuda.empty_cache() + + torch.cuda.synchronize() if torch.cuda.is_available() else None + weight_update_time = time.time() - weight_update_start + + total_time = time.time() - total_start + print_rank_0( + f"GPTQ timing - Hessian: {hessian_time:.2f}s, " 
+ f"Weight update: {weight_update_time:.2f}s, " + f"Total: {total_time:.2f}s" + ) From df6b18285752a0ec377f1cae1ae582f76543ca2c Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Fri, 6 Mar 2026 22:20:44 +0000 Subject: [PATCH 26/48] removed perplexity eval Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- examples/llm_ptq/hf_ptq.py | 33 +-------------------------------- 1 file changed, 1 insertion(+), 32 deletions(-) diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index 8c1274dc88..cc0e9beda5 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -24,7 +24,6 @@ import numpy as np import torch from accelerate.hooks import remove_hook_from_module -from eval_perplexity import evaluate_perplexity from example_utils import ( build_quant_cfg, copy_custom_model_files, @@ -927,7 +926,6 @@ def quantize_main( else: # mono quantization -<<<<<<< HEAD if args.recipe is not None: print(f"Use recipe {args.recipe} for quantization") recipe = load_recipe(args.recipe) @@ -935,26 +933,6 @@ def quantize_main( f"Expected PTQ recipe, but got {type(recipe).__name__} from {args.recipe}" ) quant_cfg = recipe.ptq_cfg -======= - assert ( - args.qformat - in [ - "int8_wo", - "int4_awq", - "fp8", - "nvfp4", - "nvfp4_awq", - "nvfp4_mse", - "nvfp4_gptq", - "w4a8_awq", - "fp8_pb_wo", - "w4a8_mxfp4_fp8", - "nvfp4_mlp_only", - "mxfp8", - ] - or args.kv_cache_qformat in KV_QUANT_CFG_CHOICES - ), f"Plain quantization format {args.qformat} not supported for HF export path" ->>>>>>> 6b8812d6 (tested e2e on qwen) else: assert len(args.qformat.split(",")) == 1, ( @@ -1027,10 +1005,6 @@ def quantize_main( first_text_speech_dataset, ) - if args.eval_perplexity and tokenizer is not None: - print("Evaluating Wikitext-2 perplexity...") - evaluate_perplexity(language_model, tokenizer, seq_len=args.calib_seq) - export_quantized( args, full_model, @@ -1189,12 +1163,7 @@ def parse_args() -> 
argparse.Namespace: default=False, action="store_true", ) - parser.add_argument( - "--eval_perplexity", - help="Evaluate Wikitext-2 perplexity after quantization.", - default=False, - action="store_true", - ) + parser.add_argument( "--low_memory_mode", help=( From 75a08fea0b03154e1d2298fa77ad3c94c9d6ccae Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Fri, 6 Mar 2026 23:39:59 +0000 Subject: [PATCH 27/48] update Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- examples/llm_ptq/hf_ptq.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index cc0e9beda5..4d036fb23d 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -24,6 +24,7 @@ import numpy as np import torch from accelerate.hooks import remove_hook_from_module +from eval_perplexity import evaluate_perplexity from example_utils import ( build_quant_cfg, copy_custom_model_files, @@ -681,6 +682,9 @@ def export_quantized( "They will be set at deployment time." ) + if getattr(args, "eval_perplexity", False) and tokenizer is not None: + evaluate_perplexity(full_model, tokenizer, seq_len=2048) + # Load any missing weights from non-standard safetensors (handled in get_model for non-low-memory mode) # Store the MTP layer prefixes on the model for later exclusion from quantization mtp_layer_prefixes, mtp_state_dict = load_mtp_weights(full_model, args.pyt_ckpt_path) @@ -1222,6 +1226,12 @@ def parse_args() -> argparse.Namespace: "Does not impact non-MOE models." 
), ) + parser.add_argument( + "--eval_perplexity", + action=argparse.BooleanOptionalAction, + default=False, + help="Evaluate Wikitext-2 perplexity after quantization (before export).", + ) args = parser.parse_args() if args.moe_calib_experts_ratio is not None and not (0.0 < args.moe_calib_experts_ratio <= 1.0): From 9e58a6fba347bccebc961eb0e0e45924427e18bc Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Mon, 16 Mar 2026 18:26:47 +0000 Subject: [PATCH 28/48] revert later Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- examples/llm_ptq/hf_ptq.py | 123 +++++++++++++++++++++++- modelopt/torch/quantization/__init__.py | 8 +- modelopt/torch/quantization/config.py | 16 +++ 3 files changed, 144 insertions(+), 3 deletions(-) diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index 4d036fb23d..8297a977b5 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -24,7 +24,6 @@ import numpy as np import torch from accelerate.hooks import remove_hook_from_module -from eval_perplexity import evaluate_perplexity from example_utils import ( build_quant_cfg, copy_custom_model_files, @@ -63,6 +62,11 @@ ) from modelopt.torch.export.model_utils import get_language_model_from_vl, is_multimodal_model from modelopt.torch.quantization.config import _default_disabled_quantizer_cfg, need_calibration +from modelopt.torch.quantization.metrics import ( + ActivationMSELogger, + compute_perplexity, + get_wikitext2, +) from modelopt.torch.quantization.plugins.accelerate import init_quantized_weights from modelopt.torch.quantization.utils import is_quantized from modelopt.torch.utils.dataset_utils import ( @@ -99,6 +103,7 @@ def _set_kv_cache_constant_amax(quant_cfg: dict) -> None: "int4_awq": mtq.INT4_AWQ_CFG, "w4a8_awq": mtq.W4A8_AWQ_BETA_CFG, "nvfp4": mtq.NVFP4_DEFAULT_CFG, + "nvfp4_wo": mtq.NVFP4_WEIGHT_ONLY_CFG, "nvfp4_awq": mtq.NVFP4_AWQ_LITE_CFG, "nvfp4_mse": 
mtq.NVFP4_W4A4_WEIGHT_MSE_FP8_SWEEP_CFG, "fp8_pb_wo": mtq.FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG, @@ -109,6 +114,7 @@ def _set_kv_cache_constant_amax(quant_cfg: dict) -> None: "nvfp4_experts_only": mtq.NVFP4_EXPERTS_ONLY_CFG, "nvfp4_omlp_only": mtq.NVFP4_OMLP_ONLY_CFG, "nvfp4_svdquant": mtq.NVFP4_SVDQUANT_DEFAULT_CFG, + "nvfp4_wo_gptq": mtq.NVFP4_WEIGHT_ONLY_GPTQ_CFG, "nvfp4_gptq": mtq.NVFP4_GPTQ_CFG, "mxfp8": mtq.MXFP8_DEFAULT_CFG, "nvfp4_local_hessian": mtq.NVFP4_W4A4_WEIGHT_LOCAL_HESSIAN_CFG, @@ -683,7 +689,10 @@ def export_quantized( ) if getattr(args, "eval_perplexity", False) and tokenizer is not None: - evaluate_perplexity(full_model, tokenizer, seq_len=2048) + seq_len = getattr(args, "eval_perplexity_seq_len", 2048) + eval_data = get_wikitext2(tokenizer, seq_len) + ppl = compute_perplexity(full_model, eval_data) + print(f"Wikitext-2 perplexity: {round(ppl, 2):.2f}") # Load any missing weights from non-standard safetensors (handled in get_model for non-low-memory mode) # Store the MTP layer prefixes on the model for later exclusion from quantization @@ -916,6 +925,64 @@ def quantize_main( args, full_model, model_type, tokenizer, calib_dataloader, is_nemotron_vl_model ) + # Collect original (unquantized) activations before quantization modifies the model + mse_logger = None + if getattr(args, "measure_activation_mse", False): + n_mse = getattr(args, "activation_mse_max_samples", 16) + mse_save_dir = getattr(args, "activation_mse_save_dir", None) + mse_input_path = getattr(args, "activation_mse_input_path", None) + + # Resolve MSE input data: frozen file (raw text or tokenized) or live dataloader + mse_data = None + if mse_input_path is not None: + if mse_input_path.endswith(".json"): + if os.path.isfile(mse_input_path): + print(f"Loading MSE input data from existing .json file: {mse_input_path}") + texts = ActivationMSELogger.load_raw_text(mse_input_path) + mse_data = ActivationMSELogger.tokenize_raw_text( + texts, + tokenizer, + max_length=args.calib_seq, + ) + 
else: + assert tokenizer is not None, ( + "--activation_mse_input_path with .json requires a tokenizer to decode" + ) + print(f"Creating MSE input data .json file: {mse_input_path}") + texts = ActivationMSELogger.materialize_raw_text( + calib_dataloader, + mse_input_path, + tokenizer=tokenizer, + max_samples=n_mse, + ) + mse_data = ActivationMSELogger.tokenize_raw_text( + texts, + tokenizer, + max_length=args.calib_seq, + ) + elif mse_input_path.endswith(".pt"): + if os.path.isfile(mse_input_path): + print(f"Loading MSE input data from existing .pt file: {mse_input_path}") + mse_data = ActivationMSELogger.load_data(mse_input_path) + else: + print(f"Creating MSE input data .pt file: {mse_input_path}") + mse_data = ActivationMSELogger.materialize_data( + calib_dataloader, + mse_input_path, + max_samples=n_mse, + ) + else: + raise ValueError( + f"--activation_mse_input_path must end with .json or .pt, got: {mse_input_path}" + ) + + if mse_data is None: + mse_data = calib_dataloader + + mse_logger = ActivationMSELogger(max_samples=n_mse, save_dir=mse_save_dir) + print(f"Collecting original (unquantized) activations for MSE over {n_mse} samples...") + mse_logger.collect(language_model, mse_data, phase="original") + if args.auto_quantize_bits: assert len(args.qformat.split(",")) > 1, ( "Auto quantization needs multiple quantization format." 
@@ -1009,6 +1076,22 @@ def quantize_main( first_text_speech_dataset, ) + if mse_logger is not None: + import gc + + print("Collecting quantized activations for MSE...") + mse_logger.collect(language_model, mse_data, phase="quantized") + + mse_logger.compute_mse() + print(mse_logger.summary()) + + if getattr(args, "activation_mse_save_dir", None): + mse_logger.save() + + del mse_logger, mse_data + gc.collect() + torch.cuda.empty_cache() + export_quantized( args, full_model, @@ -1232,6 +1315,42 @@ def parse_args() -> argparse.Namespace: default=False, help="Evaluate Wikitext-2 perplexity after quantization (before export).", ) + parser.add_argument( + "--eval_perplexity_seq_len", + type=int, + default=2048, + help="Sequence length for perplexity evaluation (default: 2048).", + ) + parser.add_argument( + "--measure_activation_mse", + action=argparse.BooleanOptionalAction, + default=False, + help="Measure per-layer activation MSE (original vs quantized) after quantization.", + ) + parser.add_argument( + "--activation_mse_max_samples", + type=int, + default=16, + help="Max calibration samples for activation MSE (default: 16).", + ) + parser.add_argument( + "--activation_mse_save_dir", + type=str, + default=None, + help="Directory to save activation MSE results. If not set, results are only printed.", + ) + parser.add_argument( + "--activation_mse_input_path", + type=str, + default=None, + help=( + "Path to frozen MSE input data. Supports two formats:\n" + " .json — raw text (cross-model reuse): if file exists, loads and re-tokenizes " + "with the current model's tokenizer; if not, decodes calibration data to text and saves.\n" + " .pt — tokenized tensors (same-tokenizer reuse): if file exists, loads directly; " + "if not, materializes from calibration data and saves." 
+ ), + ) args = parser.parse_args() if args.moe_calib_experts_ratio is not None and not (0.0 < args.moe_calib_experts_ratio <= 1.0): diff --git a/modelopt/torch/quantization/__init__.py b/modelopt/torch/quantization/__init__.py index 87dbf30bb5..d471e55823 100644 --- a/modelopt/torch/quantization/__init__.py +++ b/modelopt/torch/quantization/__init__.py @@ -16,12 +16,18 @@ """Quantization package.""" # Initialize mode and plugins -from . import mode, plugins, utils +from . import metrics, mode, plugins, utils # Add methods to mtq namespace from .compress import * from .config import * from .conversion import * +from .metrics import ( + ActivationMSELogger, + compute_perplexity, + get_wikitext2, + measure_per_layer_activation_mse, +) from .model_quant import * from .nn.modules.quant_module import QuantModuleRegistry from .utils import update_quant_cfg_with_kv_cache_quant diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index 2b96e35e03..8c1ee07065 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -476,6 +476,20 @@ def _nvfp4_selective_quant_cfg( }, } +NVFP4_WEIGHT_ONLY_GPTQ_CFG = { + "quant_cfg": { + "*weight_quantizer": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + }, + "*input_quantizer": {"enable": False}, + **_default_disabled_quantizer_cfg, + }, + "algorithm": {"method": "gptq", "use_sequential": True}, +} + NVFP4_GPTQ_CFG = { "quant_cfg": { "*weight_quantizer": { @@ -652,6 +666,8 @@ def _nvfp4_selective_quant_cfg( "NVFP4_AWQ_LITE_CFG", "NVFP4_DEFAULT_CFG", "NVFP4_GPTQ_CFG", + "NVFP4_WEIGHT_ONLY_CFG", + "NVFP4_WEIGHT_ONLY_GPTQ_CFG", "NVFP4_FP8_MHA_CONFIG", "NVFP4_KV_CFG", "NVFP4_KV_ROTATE_CFG", From 16086c78f93d84b01b09d79d93869d7595fffff5 Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Thu, 19 Mar 2026 06:32:45 +0000 Subject: [PATCH 29/48] 
minor update Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- modelopt/torch/utils/network.py | 1 - 1 file changed, 1 deletion(-) diff --git a/modelopt/torch/utils/network.py b/modelopt/torch/utils/network.py index b07ca570c4..b54332375b 100644 --- a/modelopt/torch/utils/network.py +++ b/modelopt/torch/utils/network.py @@ -46,7 +46,6 @@ def _convert_to_wrapped_module_name(name: str) -> str: "ModelLike", "compare_dict", "create_param_grad_clear_hook", - "get_decoder_layers", "get_model_attributes", "get_module_device", "get_same_padding", From 9b47e770726f12fb81fcbbcc5eefa1ab0400f2fd Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Wed, 18 Mar 2026 06:35:56 +0000 Subject: [PATCH 30/48] update Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- modelopt/torch/quantization/model_calib.py | 142 +++++++++++++++++++-- 1 file changed, 129 insertions(+), 13 deletions(-) diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index d27151dcbb..79c184b27f 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -1609,6 +1609,103 @@ def prepare_hessian_inverse(h, weight, percdamp): return h_inv +def _build_column_qdq(quantizer, weight_shape): + """Build a fast column-wise quantize-dequantize function for integer quantizers. + + Instead of calling the full TensorQuantizer on the entire weight matrix (which + quantizes all elements) and extracting one column, this returns a closure that + quantizes only a single column using the quantizer's pre-computed amax/scales. + + Since max_calibrate fixes the amax before GPTQ weight updates, quantizing a + single column with the same fixed scale gives bit-identical results to + quantizing the full matrix and extracting that column. + + Args: + quantizer: The weight TensorQuantizer (already calibrated). 
+ weight_shape: Shape of the weight tensor (out_features, in_features). + + Returns: + Tuple of (column_qdq_fn, supported) where: + - column_qdq_fn(column, col_idx) -> qdq_column (if supported) + - supported: True if column-wise qdq is available, False to fall back. + """ + # Unsupported: NVFP4 (two-level FP4 scaling), FP quantization (num_bits is a tuple) + if isinstance(quantizer, NVFP4StaticQuantizer): + return None, False + if isinstance(quantizer._num_bits, tuple): + return None, False + + # Unsupported: pre_quant_scale (SmoothQuant) or rotation transforms mix columns + if getattr(quantizer, "pre_quant_scale", None) is not None: + return None, False + if getattr(quantizer, "rotate_is_enabled", False): + return None, False + + # Need calibrated amax + if not hasattr(quantizer, "_amax") or quantizer._amax is None: + return None, False + + num_bits = quantizer._num_bits + unsigned = getattr(quantizer, "_unsigned", False) + narrow_range = getattr(quantizer, "_narrow_range", False) + max_bound = (2 ** (num_bits - 1 + int(unsigned))) - 1 + min_bound = -max_bound + int(narrow_range) + + amax = quantizer._amax.float() + out_features, in_features = weight_shape + + # Determine quantization geometry from block_sizes + block_sizes = quantizer.block_sizes + group_size = None + if block_sizes is not None: + # Skip dynamic block quantization + if block_sizes.get("type", "static") == "dynamic": + return None, False + group_size = block_sizes.get(-1, None) or block_sizes.get(len(weight_shape) - 1, None) + + if group_size is not None and group_size > 0: + # Per-group block quantization along last dim. + # After _setup_for_blockquant, weight is reshaped to (-1, group_size) with axis=(0,). + # amax shape: (out_features * n_groups, 1) where n_groups = in_features // group_size. 
+ if in_features % group_size != 0: + return None, False # Padding case — fall back + + n_groups = in_features // group_size + + try: + # Reshape amax to (out_features, n_groups) for O(1) group lookup + amax_2d = amax.reshape(out_features, n_groups) + except RuntimeError: + return None, False + + def _column_qdq_group( + col, col_idx, _a=amax_2d, _mx=max_bound, _mn=min_bound, _gs=group_size + ): + col_scale = _mx / _a[:, col_idx // _gs].clamp(min=1e-12) + return torch.clamp(torch.round(col * col_scale), _mn, _mx) / col_scale + + return _column_qdq_group, True + + # Per-channel (axis != None) or per-tensor (axis == None) + axis = quantizer.axis + if axis is not None: + # Per-channel: amax has shape (out_features, 1) or similar + col_scale = max_bound / amax.reshape(-1).clamp(min=1e-12) + + def _column_qdq_channel(col, col_idx, _s=col_scale, _mx=max_bound, _mn=min_bound): + return torch.clamp(torch.round(col * _s), _mn, _mx) / _s + + return _column_qdq_channel, True + + # Per-tensor: single scalar scale + scalar_scale = max_bound / amax.clamp(min=1e-12).item() + + def _column_qdq_tensor(col, col_idx, _s=scalar_scale, _mx=max_bound, _mn=min_bound): + return torch.clamp(torch.round(col * _s), _mn, _mx) / _s + + return _column_qdq_tensor, True + + def blockwise_weight_update(module, h, block_size, percdamp, n_samples=None): """Update module weights using GPTQ-style blockwise quantization. 
@@ -1625,22 +1722,41 @@ def blockwise_weight_update(module, h, block_size, percdamp, n_samples=None): # Preprocess Hessian: handle dead neurons and add damping h_inv = prepare_hessian_inverse(h, weight, percdamp) + # Try to build fast column-wise qdq (avoids quantizing the full matrix per column) + col_qdq_fn, col_qdq_supported = _build_column_qdq(module.weight_quantizer, weight.shape) + # Process weights in blocks for block_start in range(0, num_cols, block_size): block_end = min(block_start + block_size, num_cols) n_cols = block_end - block_start - wblk = weight.clone() - errs = torch.zeros_like(wblk[:, block_start:block_end]) h_inv_cho_blk = h_inv[block_start:block_end, block_start:block_end] - for i in range(n_cols): - w_ci = wblk[:, block_start + i] - d = h_inv_cho_blk[i, i] - qdq = module.weight_quantizer(wblk) - weight[:, block_start + i] = qdq[:, block_start + i] - err = (w_ci - qdq[:, block_start + i]) / d - wblk[:, block_start + i : block_end].addr_(err, h_inv_cho_blk[i, i:], alpha=-1) - errs[:, i] = err + if col_qdq_supported: + # Fast path: clone only the block columns, quantize only per-column + wblk = weight[:, block_start:block_end].clone() + errs = torch.zeros_like(wblk) + + for i in range(n_cols): + w_ci = wblk[:, i] + d = h_inv_cho_blk[i, i] + qdq_col = col_qdq_fn(w_ci, block_start + i) + weight[:, block_start + i] = qdq_col + err = (w_ci - qdq_col) / d + wblk[:, i:].addr_(err, h_inv_cho_blk[i, i:], alpha=-1) + errs[:, i] = err + else: + # Fallback: original full-matrix quantization path + wblk = weight.clone() + errs = torch.zeros_like(wblk[:, block_start:block_end]) + + for i in range(n_cols): + w_ci = wblk[:, block_start + i] + d = h_inv_cho_blk[i, i] + qdq = module.weight_quantizer(wblk) + weight[:, block_start + i] = qdq[:, block_start + i] + err = (w_ci - qdq[:, block_start + i]) / d + wblk[:, block_start + i : block_end].addr_(err, h_inv_cho_blk[i, i:], alpha=-1) + errs[:, i] = err # Propagate errors to remaining weights weight[:, 
block_end:].addmm_(errs, h_inv[block_start:block_end, block_end:], alpha=-1) @@ -1844,7 +1960,7 @@ def _layer_forward_loop(m, _inputs=layer_inputs): torch.cuda.empty_cache() finally: input_getter._unpatch_all_layers() - + print_rank_0("Sequential calibration completed") @@ -1969,9 +2085,9 @@ def hessian_forward(self, input, *args, **kwargs): blockwise_weight_update( module, hessian, block_size, percdamp, n_samples=state["n_samples"] ) - # Free memory del hessian_state[module.name] - torch.cuda.empty_cache() + if torch.cuda.is_available(): + torch.cuda.empty_cache() torch.cuda.synchronize() if torch.cuda.is_available() else None weight_update_time = time.time() - weight_update_start From 4ec24335e4ccae25cefa58283b9b7c29b1473b01 Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Wed, 18 Mar 2026 23:30:36 +0000 Subject: [PATCH 31/48] gptq faster Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- modelopt/torch/quantization/model_calib.py | 108 +++++++--- .../quantization/triton/gptq_fused_kernel.py | 189 ++++++++++++++++++ tests/gpu/torch/quantization/test_gptq.py | 93 ++++++++- 3 files changed, 365 insertions(+), 25 deletions(-) create mode 100644 modelopt/torch/quantization/triton/gptq_fused_kernel.py diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 79c184b27f..1e8a94b3db 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -1545,7 +1545,7 @@ def _print_relative_mse_error( delta = q - w mse = (delta).mm(h).mul(delta).mean() / (w.mm(h).mul(w).mean() + 1e-6) suffix = f", n_hessian_samples: {n_samples}" if n_samples is not None else "" - print(f"[{module_name}] Relative MSE error: {mse.item():.2e}{suffix}") + print_rank_0(f"[{module_name}] Relative MSE error: {mse.item():.2e}{suffix}") def update_hessian(input, hessian, n_samples): @@ -1604,7 +1604,7 @@ def prepare_hessian_inverse(h, 
weight, percdamp): h = torch.cholesky_inverse(torch.linalg.cholesky(h)) h_inv = torch.linalg.cholesky(h, upper=True) except (RuntimeError, torch.linalg.LinAlgError): - print("Warning: Hessian is not positive definite, using identity matrix") + print_rank_0("Warning: Hessian is not positive definite, using identity matrix") h_inv = torch.eye(h.shape[0], device=h.device, dtype=h.dtype) return h_inv @@ -1706,37 +1706,104 @@ def _column_qdq_tensor(col, col_idx, _s=scalar_scale, _mx=max_bound, _mn=min_bou return _column_qdq_tensor, True +def _can_use_fused_gptq(quantizer) -> bool: + """Check whether the fused Triton GPTQ kernel can be used for *quantizer*.""" + if not isinstance(quantizer, NVFP4StaticQuantizer): + return False + if not hasattr(quantizer, "_amax") or quantizer._amax is None: + return False + from modelopt.torch.quantization.triton import IS_AVAILABLE as _TRITON_OK + + return _TRITON_OK + + def blockwise_weight_update(module, h, block_size, percdamp, n_samples=None): """Update module weights using GPTQ-style blockwise quantization. + Dispatches to one of three internal paths depending on quantizer type: + + 1. **Fused Triton** — for :class:`NVFP4StaticQuantizer` when Triton is + available. Runs the entire column loop in a single GPU kernel per + block (~130x faster than the unfused path on Blackwell GPUs). + 2. **Column-QDQ** — for integer quantizers whose scale geometry allows + single-column fake-quant via :func:`_build_column_qdq`. + 3. **Full-matrix fallback** — calls the quantizer on the full weight matrix + each column (slowest, but always correct). + Args: - module: Neural network module with weight and weight_quantizer - H: Hessian matrix (d x d) - block_size: Size of blocks to process at once - percdamp: Damping percentage for Hessian diagonal - n_samples: Number of Hessian samples for logging (optional) + module: Neural network module with ``weight`` and ``weight_quantizer``. + h: Hessian matrix of shape ``(d, d)``. 
+ block_size: Number of columns processed per block. + percdamp: Damping as a fraction of the mean Hessian diagonal. + n_samples: Number of Hessian samples (used only for logging). """ weight = module.weight.data.float().clone() - _, num_cols = weight.shape + num_rows, num_cols = weight.shape - # Preprocess Hessian: handle dead neurons and add damping h_inv = prepare_hessian_inverse(h, weight, percdamp) - # Try to build fast column-wise qdq (avoids quantizing the full matrix per column) - col_qdq_fn, col_qdq_supported = _build_column_qdq(module.weight_quantizer, weight.shape) + quantizer = module.weight_quantizer + if _can_use_fused_gptq(quantizer): + _blockwise_weight_update_fused(weight, h_inv, quantizer, num_rows, num_cols, block_size) + else: + col_qdq_fn, col_qdq_supported = _build_column_qdq(quantizer, weight.shape) + _blockwise_weight_update_unfused( + weight, h_inv, quantizer, num_cols, block_size, col_qdq_fn, col_qdq_supported + ) + + _print_relative_mse_error(weight, module.weight.float(), h, module.name, n_samples) + module.weight.data = weight.reshape(module.weight.shape).to(module.weight.data.dtype) + + +def _blockwise_weight_update_fused(weight, h_inv, quantizer, num_rows, num_cols, block_size): + """Fused Triton path for NVFP4: one kernel launch per block.""" + from modelopt.torch.quantization.triton.gptq_fused_kernel import gptq_fused_block + + group_size = quantizer.block_sizes.get(-1, None) or quantizer.block_sizes.get(1, None) + num_groups = math.ceil(num_cols / group_size) + amax_grouped = quantizer._amax.float().reshape(num_rows, num_groups).contiguous() + global_amax = quantizer.global_amax.float() - # Process weights in blocks for block_start in range(0, num_cols, block_size): block_end = min(block_start + block_size, num_cols) - n_cols = block_end - block_start + n_cols_blk = block_end - block_start + + w_block = weight[:, block_start:block_end].clone().contiguous() + h_inv_cho_blk = h_inv[block_start:block_end, 
block_start:block_end].contiguous() + + qw_block, err_block = gptq_fused_block( + w_block, + amax_grouped, + global_amax, + h_inv_cho_blk, + group_size, + block_start, + n_cols_blk, + ) + + weight[:, block_start:block_end] = qw_block + if block_end < num_cols: + weight[:, block_end:].addmm_( + err_block[:, :n_cols_blk], + h_inv[block_start:block_end, block_end:], + alpha=-1, + ) + + +def _blockwise_weight_update_unfused( + weight, h_inv, quantizer, num_cols, block_size, col_qdq_fn, col_qdq_supported +): + """Column-QDQ or full-matrix fallback for non-NVFP4 quantizers.""" + for block_start in range(0, num_cols, block_size): + block_end = min(block_start + block_size, num_cols) + n_cols_blk = block_end - block_start h_inv_cho_blk = h_inv[block_start:block_end, block_start:block_end] if col_qdq_supported: - # Fast path: clone only the block columns, quantize only per-column wblk = weight[:, block_start:block_end].clone() errs = torch.zeros_like(wblk) - for i in range(n_cols): + for i in range(n_cols_blk): w_ci = wblk[:, i] d = h_inv_cho_blk[i, i] qdq_col = col_qdq_fn(w_ci, block_start + i) @@ -1745,27 +1812,20 @@ def blockwise_weight_update(module, h, block_size, percdamp, n_samples=None): wblk[:, i:].addr_(err, h_inv_cho_blk[i, i:], alpha=-1) errs[:, i] = err else: - # Fallback: original full-matrix quantization path wblk = weight.clone() errs = torch.zeros_like(wblk[:, block_start:block_end]) - for i in range(n_cols): + for i in range(n_cols_blk): w_ci = wblk[:, block_start + i] d = h_inv_cho_blk[i, i] - qdq = module.weight_quantizer(wblk) + qdq = quantizer(wblk) weight[:, block_start + i] = qdq[:, block_start + i] err = (w_ci - qdq[:, block_start + i]) / d wblk[:, block_start + i : block_end].addr_(err, h_inv_cho_blk[i, i:], alpha=-1) errs[:, i] = err - # Propagate errors to remaining weights weight[:, block_end:].addmm_(errs, h_inv[block_start:block_end, block_end:], alpha=-1) - # Print relative mse error - _print_relative_mse_error(weight, module.weight.float(), 
h, module.name, n_samples) - # Update module weights - module.weight.data = weight.reshape(module.weight.shape).to(module.weight.data.dtype) - def gptq_lite( model: nn.Module, diff --git a/modelopt/torch/quantization/triton/gptq_fused_kernel.py b/modelopt/torch/quantization/triton/gptq_fused_kernel.py new file mode 100644 index 0000000000..21d84713a1 --- /dev/null +++ b/modelopt/torch/quantization/triton/gptq_fused_kernel.py @@ -0,0 +1,189 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Fused Triton kernel for the GPTQ blockwise weight-update inner loop. + +The standard GPTQ inner loop launches ~10-15 CUDA kernels per column +(amax lookup, FP4 quantization, error computation, rank-1 update). +For ``block_size=128`` that is ~1 500 kernel launches per block, each with +~5-10 us of launch overhead dominating actual compute. + +This module fuses the entire inner loop into a **single** Triton kernel per +block. Rows are independent and map to Triton programs; columns are processed +sequentially inside each program so the rank-1 error update is carried forward +without synchronisation. + +Supported quantisation format: **NVFP4 static block quantisation** (two-level +scaling with per-group amax and a global amax). 
+""" + +import torch +import triton +import triton.language as tl + +__all__ = ["gptq_fused_block"] + +# -- NVFP4 constants used by the kernel ------------------------------------ +# Maximum representable FP4-E2M1 value (1 + 1 + 0.5 = 6.0 when decoded via +# the standard E2M1 table: {0, 0.5, 1, 1.5, 2, 3, 4, 6}). +_FP4_MAX = 6.0 +# FP8-E4M3 has max representable value 448. +_FP8_E4M3_MAX = 448.0 + + +@triton.jit +def _gptq_fused_block_kernel( + w_ptr, # [num_rows, BLOCK_SIZE] working weight block (in-place) + qw_ptr, # [num_rows, BLOCK_SIZE] output: quantized weights + err_ptr, # [num_rows, BLOCK_SIZE] output: quantization errors + amax_ptr, # [num_rows, num_groups] per-group amax, row-major + global_amax_ptr, # scalar float32 on device + hinv_ptr, # [BLOCK_SIZE, BLOCK_SIZE] upper Cholesky of H^{-1} + num_rows, + num_groups, + group_size: tl.constexpr, + block_start, # column offset of this block in the full weight matrix + n_cols, # actual columns in this block (may be < BLOCK_SIZE) + BLOCK_SIZE: tl.constexpr, +): + """One program per row; sequentially quantizes columns, propagating errors.""" + row = tl.program_id(0) + if row >= num_rows: + return + + # Base pointers for this row + w_base = w_ptr + row * BLOCK_SIZE + qw_base = qw_ptr + row * BLOCK_SIZE + err_base = err_ptr + row * BLOCK_SIZE + amax_row_base = amax_ptr + row * num_groups + + # Pre-compute global FP8 scale factors (constant across columns) + global_amax = tl.load(global_amax_ptr).to(tl.float32) + global_scale = global_amax / 6.0 # _FP4_MAX + fp8_inv_scale = tl.where(global_scale > 0.0, 1.0 / (448.0 / global_scale), 0.0) + + j_range = tl.arange(0, BLOCK_SIZE) + + for i in range(BLOCK_SIZE): + wi = tl.load(w_base + i) + + # -- Compute NVFP4 two-level scale for this column's group ----------- + col_idx = block_start + i + group_idx = col_idx // group_size + raw_amax = tl.load(amax_row_base + group_idx).to(tl.float32) + raw_scale = raw_amax / 6.0 # _FP4_MAX + + # FP8-quantize the block scale: scale * 
fp8_scale -> cast E4M3 -> back + fp8_scale = tl.where(global_scale > 0.0, 448.0 / global_scale, 1.0) + si = (raw_scale * fp8_scale).to(tl.float8e4nv).to(tl.float32) * fp8_inv_scale + + # Guard: replace zero / nan / inf scale with 1.0 + # NOTE: ``si != si`` is the standard NaN check in Triton (no math.isnan). + si_safe = tl.where( + (si == 0.0) | (si != si) | (tl.abs(si) == float("inf")), # noqa: PLR0124 + 1.0, + si, + ) + + # -- FP4-E2M1 fake quantization (nearest-round to 8 levels) ---------- + abs_scaled = tl.abs(wi) / si_safe + q_val = tl.where( + abs_scaled <= 0.25, + 0.0, + tl.where( + abs_scaled < 0.75, + 0.5, + tl.where( + abs_scaled <= 1.25, + 1.0, + tl.where( + abs_scaled < 1.75, + 1.5, + tl.where( + abs_scaled <= 2.5, + 2.0, + tl.where(abs_scaled < 3.5, 3.0, tl.where(abs_scaled <= 5.0, 4.0, 6.0)), + ), + ), + ), + ), + ) + + qi = q_val * si_safe * tl.where(wi >= 0.0, 1.0, -1.0) + tl.store(qw_base + i, qi) + + # -- GPTQ error and rank-1 update ------------------------------------ + di = tl.load(hinv_ptr + i * BLOCK_SIZE + i) + err_i = (wi - qi) / di + tl.store(err_base + i, err_i) + + j_mask = (j_range > i) & (j_range < n_cols) + hinv_row = tl.load(hinv_ptr + i * BLOCK_SIZE + j_range, mask=j_mask, other=0.0) + w_rem = tl.load(w_base + j_range, mask=j_mask, other=0.0) + w_rem = w_rem - err_i * hinv_row + tl.store(w_base + j_range, w_rem, mask=j_mask) + + +def gptq_fused_block( + w_block: torch.Tensor, + amax_grouped: torch.Tensor, + global_amax: torch.Tensor, + h_inv_cho_blk: torch.Tensor, + group_size: int, + block_start: int, + n_cols: int, +) -> tuple[torch.Tensor, torch.Tensor]: + """Run the GPTQ column loop for one block in a single Triton kernel launch. + + Args: + w_block: Working weight block of shape ``[num_rows, block_size]`` (will be cloned). + amax_grouped: Per-group amax of shape ``[num_rows, num_groups]``. + global_amax: Scalar tensor with the global amax. + h_inv_cho_blk: Upper Cholesky factor of H^{-1}, shape ``[block_size, block_size]``. 
+ group_size: NVFP4 quantization group size (typically 16). + block_start: Column offset of this block in the full weight matrix. + n_cols: Actual number of columns in this block (``<= block_size``). + + Returns: + Tuple of ``(qw_block, err_block)`` each of shape ``[num_rows, block_size]``. + """ + num_rows, block_size = w_block.shape + num_groups = amax_grouped.shape[1] + + w_block = w_block.contiguous() + amax_grouped = amax_grouped.contiguous() + h_inv_cho_blk = h_inv_cho_blk.contiguous() + + qw_block = torch.empty_like(w_block) + err_block = torch.empty_like(w_block) + + grid = (num_rows,) + with torch.cuda.device(w_block.device): + _gptq_fused_block_kernel[grid]( + w_block, + qw_block, + err_block, + amax_grouped, + global_amax, + h_inv_cho_blk, + num_rows, + num_groups, + group_size, + block_start, + n_cols, + BLOCK_SIZE=block_size, + ) + + return qw_block, err_block diff --git a/tests/gpu/torch/quantization/test_gptq.py b/tests/gpu/torch/quantization/test_gptq.py index c47b48b1e2..23bdf6cbff 100644 --- a/tests/gpu/torch/quantization/test_gptq.py +++ b/tests/gpu/torch/quantization/test_gptq.py @@ -21,7 +21,14 @@ import modelopt.torch.quantization as mtq from modelopt.torch.export.unified_export_hf import _export_quantized_weight -from modelopt.torch.quantization.model_calib import blockwise_weight_update, update_hessian +from modelopt.torch.quantization.model_calib import ( + _blockwise_weight_update_fused, + _blockwise_weight_update_unfused, + blockwise_weight_update, + prepare_hessian_inverse, + update_hessian, +) +from modelopt.torch.quantization.nn import NVFP4StaticQuantizer from modelopt.torch.quantization.qtensor.nvfp4_tensor import NVFP4QTensor from modelopt.torch.utils.dataset_utils import create_forward_loop, get_dataset_dataloader @@ -295,3 +302,87 @@ def test_gptq_e2e_flow(quant_cfg): print( f"Generated ids after quantization: {tokenizer.decode(generated_ids_after_ptq[0], skip_special_tokens=True)}" ) + + +@pytest.mark.parametrize("dim", [256, 
512]) +def test_fused_vs_unfused_nvfp4(dim): + """Verify that the fused Triton GPTQ kernel produces equivalent results to the unfused path. + + The fused kernel computes NVFP4 quantisation inline using Triton intrinsics, + which can differ slightly from the PyTorch-level quantiser path (different FP + rounding order). On real models (dim >= 4096) the relative MSE difference is + typically < 0.1%; at the smaller dims used here the tolerance is set to 20%. + """ + from modelopt.torch.quantization.model_calib import _promote_nvfp4_static_quantizers + + torch.manual_seed(RAND_SEED) + block_size = min(128, dim) + + # NVFP4_WEIGHT_ONLY_GPTQ_CFG uses *static* blocks, which get promoted to + # NVFP4StaticQuantizer — the prerequisite for the fused Triton path. + quant_cfg = copy.deepcopy(mtq.NVFP4_WEIGHT_ONLY_GPTQ_CFG) + quant_cfg["algorithm"] = "max" # calibrate only, don't run GPTQ + + model = torch.nn.Linear(dim, dim, bias=False).to("cuda") + model.name = "test_fused" + original_weight = model.weight.data.clone() + inp = torch.randn(4, 32, dim, device="cuda") + + mtq.quantize(model, quant_cfg, forward_loop=lambda m: m(inp)) + + # Promote to NVFP4StaticQuantizer (normally done by gptq / sequential_calibrate) + n_promoted = _promote_nvfp4_static_quantizers(model) + assert n_promoted > 0, "Expected at least one quantizer to be promoted" + + quantizer = model.weight_quantizer + assert isinstance(quantizer, NVFP4StaticQuantizer), ( + f"Expected NVFP4StaticQuantizer, got {type(quantizer).__name__}" + ) + + # Restore original weight and compute Hessian + model.weight.data = original_weight.clone() + hessian = torch.zeros(dim, dim, dtype=torch.float32) + n_samples = 0 + hessian, n_samples = update_hessian(inp, hessian, n_samples) + hessian = hessian.to("cuda") + + # --- Run fused path --- + weight_fused = original_weight.float().clone() + num_rows, num_cols = weight_fused.shape + h_inv = prepare_hessian_inverse(hessian, weight_fused, percdamp=0.01) + 
_blockwise_weight_update_fused(weight_fused, h_inv, quantizer, num_rows, num_cols, block_size) + + # --- Run unfused path --- + weight_unfused = original_weight.float().clone() + h_inv_unfused = prepare_hessian_inverse(hessian, weight_unfused, percdamp=0.01) + _blockwise_weight_update_unfused( + weight_unfused, h_inv_unfused, quantizer, num_cols, block_size, None, False + ) + + # Both paths must produce non-trivial updates + assert not torch.equal(weight_fused, original_weight.float()), ( + "Fused path did not update weights" + ) + assert not torch.equal(weight_unfused, original_weight.float()), ( + "Unfused path did not update weights" + ) + + # Compare Hessian-weighted relative MSE + def _relative_mse(q, w, h): + delta = q - w + return (delta.mm(h).mul(delta).mean() / (w.mm(h).mul(w).mean() + 1e-6)).item() + + orig_f = original_weight.float() + mse_fused = _relative_mse(weight_fused, orig_f, hessian) + mse_unfused = _relative_mse(weight_unfused, orig_f, hessian) + + assert mse_fused > 0, "Fused MSE should be positive" + assert mse_unfused > 0, "Unfused MSE should be positive" + + # At small test dimensions, inline Triton FP4 rounding can diverge up to ~15% + # from the PyTorch path. On production-scale layers this drops below 0.1%. 
+ relative_mse_diff = abs(mse_fused - mse_unfused) / max(mse_fused, mse_unfused) + assert relative_mse_diff < 0.20, ( + f"Fused ({mse_fused:.6e}) and unfused ({mse_unfused:.6e}) MSE differ by " + f"{relative_mse_diff:.2%}, expected < 20%" + ) From 2b0af3d1ecf37a27523b983b398c679eb0289352 Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Fri, 20 Mar 2026 23:08:13 +0000 Subject: [PATCH 32/48] added metrics files, remove later Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- examples/llm_ptq/example_utils.py | 3 + .../torch/quantization/metrics/__init__.py | 28 + .../quantization/metrics/activation_mse.py | 831 ++++++++++++++++++ .../torch/quantization/metrics/perplexity.py | 81 ++ modelopt/torch/quantization/model_calib.py | 6 +- 5 files changed, 948 insertions(+), 1 deletion(-) create mode 100644 modelopt/torch/quantization/metrics/__init__.py create mode 100644 modelopt/torch/quantization/metrics/activation_mse.py create mode 100644 modelopt/torch/quantization/metrics/perplexity.py diff --git a/examples/llm_ptq/example_utils.py b/examples/llm_ptq/example_utils.py index 58eb676111..c7e00ebc65 100755 --- a/examples/llm_ptq/example_utils.py +++ b/examples/llm_ptq/example_utils.py @@ -551,6 +551,9 @@ def get_model( try: hf_config = AutoConfig.from_pretrained(ckpt_path, **config_kwargs) + if not hasattr(hf_config, "moe_latent_size"): + hf_config.moe_latent_size = None + if is_nemotron_vl(hf_config): print( "Detected Nemotron VL model from config. " diff --git a/modelopt/torch/quantization/metrics/__init__.py b/modelopt/torch/quantization/metrics/__init__.py new file mode 100644 index 0000000000..a1c737c3c0 --- /dev/null +++ b/modelopt/torch/quantization/metrics/__init__.py @@ -0,0 +1,28 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# mypy: ignore-errors + +"""Metrics for evaluating quantized models.""" + +from .activation_mse import ActivationMSELogger, measure_per_layer_activation_mse +from .perplexity import compute_perplexity, get_wikitext2 + +__all__ = [ + "ActivationMSELogger", + "compute_perplexity", + "get_wikitext2", + "measure_per_layer_activation_mse", +] diff --git a/modelopt/torch/quantization/metrics/activation_mse.py b/modelopt/torch/quantization/metrics/activation_mse.py new file mode 100644 index 0000000000..1b60977ee1 --- /dev/null +++ b/modelopt/torch/quantization/metrics/activation_mse.py @@ -0,0 +1,831 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# mypy: ignore-errors +# ruff: noqa: D107, D205, PERF401, PLR0124 + +"""Per-layer activation MSE between original (unquantized) and quantized model. + +Includes the portable ``ActivationMSELogger`` class that works across codebases +(FP-Quant List[Tensor] style *and* ModelOpt DataLoader-of-dicts style). + +Ported from FP-Quant: https://github.com/IST-DASLab/FP-Quant +""" + +import fnmatch +import gc +import hashlib +import json +import os +from datetime import datetime + +import torch +import torch.nn as nn +import torch.nn.functional as F +from tqdm import tqdm + + +def _get_module(block: nn.Module, name: str) -> nn.Module: + """Get submodule from block by dotted name, e.g. 'self_attn.q_proj'.""" + obj = block + for part in name.split("."): + obj = getattr(obj, part) + return obj + + +def _get_linear_layer_names(block: nn.Module) -> list[str]: + """Collect relative names of linear layers in a transformer block (same as GPTQ).""" + names = [] + for name, layer in block.named_modules(): + if isinstance(layer, nn.Linear): + names.append(name) + return names + + +def _tensor_from_output(out) -> torch.Tensor: + """Extract a single tensor from layer output (handle tuple return).""" + if isinstance(out, torch.Tensor): + return out.detach() + return out[0].detach() + + +def _discover_layer_keys(blocks, layer_names, num_blocks): + """Build list of valid layer keys.""" + keys = [] + for i in range(num_blocks): + for name in layer_names: + try: + _get_module(blocks[i], name) + except AttributeError: + continue + keys.append(f"model.layers.{i}.{name}") + return keys + + +def _collect_outputs( + model: nn.Module, + blocks: nn.ModuleList, + layer_names: list[str], + layer_keys: list[str], + calibration_data: list[torch.Tensor], + device: torch.device | str, + num_blocks: int, + desc: str, +) -> dict[str, list[torch.Tensor]]: + """Run model on calibration data, capture per-layer outputs (moved to CPU).""" + captured: dict[str, torch.Tensor] = {} + saved: dict[str, 
list[torch.Tensor]] = {k: [] for k in layer_keys} + + def make_hook(key: str): + def hook(_module: nn.Module, _input: tuple, output) -> None: + captured[key] = _tensor_from_output(output).cpu() + + return hook + + hooks = [] + for i in range(num_blocks): + for name in layer_names: + key = f"model.layers.{i}.{name}" + if key not in saved: + continue + try: + mod = _get_module(blocks[i], name) + except AttributeError: + continue + hooks.append(mod.register_forward_hook(make_hook(key))) + + try: + for sample in tqdm(calibration_data, desc=desc, leave=False): + inp = sample.unsqueeze(0) if sample.dim() == 1 else sample + inp = inp.to(device) + captured.clear() + with torch.no_grad(): + _ = model(inp) + for key in layer_keys: + if key in captured: + saved[key].append(captured[key]) + finally: + for h in hooks: + h.remove() + return saved + + +@torch.no_grad() +def measure_per_layer_activation_mse( + model_orig: nn.Module, + model_quant: nn.Module, + calibration_data: list[torch.Tensor], + device: torch.device | str, + log_wandb: bool = False, + max_samples: int | None = None, +) -> dict[str, float]: + """Measure per-linear-layer MSE between activations of the original (unquantized) + model and the quantized model on the same calibration data. + + Runs each model on GPU one at a time to avoid OOM. + Returns a dict mapping layer key (e.g. "model.layers.0.self_attn.q_proj") to MSE. 
+ """ + if max_samples is not None and max_samples > 0: + calibration_data = calibration_data[:max_samples] + + blocks_quant = model_quant.model.layers + blocks_orig = model_orig.model.layers + num_blocks = len(blocks_quant) + assert len(blocks_orig) == num_blocks + + layer_names = _get_linear_layer_names(blocks_quant[0]) + layer_keys = _discover_layer_keys(blocks_quant, layer_names, num_blocks) + + # --- Phase 1: run quantized model on GPU, save outputs to CPU --- + print(" Phase 1/2: collecting quantized model outputs...") + model_quant.to(device) + quant_outputs = _collect_outputs( + model_quant, + blocks_quant, + layer_names, + layer_keys, + calibration_data, + device, + num_blocks, + desc="Activation MSE (quant)", + ) + # Free GPU for original model + model_quant.cpu() + gc.collect() + torch.cuda.empty_cache() + + # --- Phase 2: run original model on GPU, compute MSE vs stored quant --- + print(" Phase 2/2: collecting original model outputs and computing MSE...") + model_orig.to(device) + + # Instead of storing orig outputs, compute MSE on the fly per sample + sum_sq: dict[str, float] = dict.fromkeys(layer_keys, 0.0) + count: dict[str, int] = dict.fromkeys(layer_keys, 0) + + captured: dict[str, torch.Tensor] = {} + + def make_hook(key: str): + def hook(_module: nn.Module, _input: tuple, output) -> None: + captured[key] = _tensor_from_output(output).cpu() + + return hook + + hooks = [] + for i in range(num_blocks): + for name in layer_names: + key = f"model.layers.{i}.{name}" + if key not in sum_sq: + continue + try: + mod = _get_module(blocks_orig[i], name) + except AttributeError: + continue + hooks.append(mod.register_forward_hook(make_hook(key))) + + try: + for sample_idx, sample in enumerate( + tqdm(calibration_data, desc="Activation MSE (orig)", leave=False) + ): + inp = sample.unsqueeze(0) if sample.dim() == 1 else sample + inp = inp.to(device) + captured.clear() + _ = model_orig(inp) + for key in layer_keys: + if key not in captured: + continue + if 
sample_idx >= len(quant_outputs.get(key, [])): + continue + o = captured[key].float() + q = quant_outputs[key][sample_idx].float() + if o.shape != q.shape: + continue + sum_sq[key] += F.mse_loss(o, q, reduction="sum").item() + count[key] += o.numel() + finally: + for h in hooks: + h.remove() + + # Free original model from GPU + model_orig.cpu() + gc.collect() + torch.cuda.empty_cache() + + # Move quantized model back to GPU for downstream usage + model_quant.to(device) + + mse = { + key: (sum_sq[key] / count[key]) if count[key] > 0 else float("nan") for key in layer_keys + } + + if log_wandb: + try: + import wandb + + for key, val in mse.items(): + if val == val: # skip nan + wandb.log({f"activation_mse/{key}": val}) + except ImportError: + pass + + return mse + + +# --------------------------------------------------------------------------- +# Portable ActivationMSELogger class +# --------------------------------------------------------------------------- + + +def _matches_filter(name: str, layer_filter: str | None) -> bool: + """Check if a layer name matches the optional filter pattern (fnmatch-style).""" + if layer_filter is None: + return True + return fnmatch.fnmatch(name, layer_filter) + + +def _portable_discover_target_layers( + model: nn.Module, + layer_filter: str | None = None, +) -> dict[str, nn.Module]: + """Discover linear layers in decoder blocks with a portable fallback chain. + + Strategy: + 1. Try modelopt's ``get_decoder_layers`` (available inside ModelOpt). + 2. Try common HuggingFace attribute paths (``model.model.layers``, etc.). + 3. Fall back to scanning **all** ``nn.Linear`` in ``model.named_modules()``. + + Within each set of decoder blocks the function collects every ``nn.Linear`` + sub-module and optionally filters by *layer_filter* (fnmatch pattern). + """ + decoder_layers = None + + # 1. 
Try modelopt helper + try: + from modelopt.torch.quantization.utils.activation_collector import LayerActivationCollector + + decoder_layers = LayerActivationCollector.get_decoder_layers(model) + except Exception: + pass + + # 2. Try common HF / other patterns + if decoder_layers is None: + for attr_chain in ( + ("model", "layers"), + ("decoder", "layers"), + ("transformer", "h"), + ("backbone", "layers"), + ): + obj = model + try: + for attr in attr_chain: + obj = getattr(obj, attr) + if isinstance(obj, nn.ModuleList): + decoder_layers = obj + break + except AttributeError: + continue + + targets: dict[str, nn.Module] = {} + + if decoder_layers is not None: + module_to_name: dict[int, str] = {id(m): n for n, m in model.named_modules()} + for block in decoder_layers: + block_name = module_to_name.get(id(block), "") + for sub_name, sub_mod in block.named_modules(): + if isinstance(sub_mod, nn.Linear): + full_name = f"{block_name}.{sub_name}" if block_name else sub_name + if _matches_filter(full_name, layer_filter): + targets[full_name] = sub_mod + else: + # 3. Fallback: all linear layers + for name, module in model.named_modules(): + if isinstance(module, nn.Linear): + if _matches_filter(name, layer_filter): + targets[name] = module + + return targets + + +class ActivationMSELogger: + """Portable activation MSE logger for comparing original vs quantized models. + + Works with both: + + - ``List[Tensor]`` data (**FP-Quant** style): each element is ``[1, seq_len]`` + or ``[B, seq_len]``, consumed via ``model(tensor)``. + - ``DataLoader`` / ``Iterable`` yielding dicts (**ModelOpt** style): + ``{"input_ids": tensor, ...}``, consumed via ``model(**batch)``. + + Guarantees same samples are used for both phases via SHA-256 hashing of + input tensors. Supports saving / loading all activations to disk for + later cross-codebase comparison. 
+ + Example (FP-Quant -- List[Tensor]):: + + mse_logger = ActivationMSELogger(max_samples=16, save_dir="./mse_logs") + mse_logger.collect(model_orig, calibration_data, phase="original") + mse_logger.collect(model_quant, calibration_data, phase="quantized") + results = mse_logger.compute_mse() + print(mse_logger.summary()) + mse_logger.save() + + Example (ModelOpt -- DataLoader with dict batches):: + + mse_logger = ActivationMSELogger(max_samples=16, save_dir="./mse_logs") + mse_logger.collect(model, dataloader, phase="original") + model = mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop) + mse_logger.collect(model, dataloader, phase="quantized") + results = mse_logger.compute_mse() + print(mse_logger.summary()) + mse_logger.save() + """ + + def __init__( + self, + max_samples: int = 16, + layer_filter: str | None = None, + save_dir: str | None = None, + ): + self.max_samples = max_samples + self.layer_filter = layer_filter + self.save_dir = save_dir + + # Per-phase state + self.original_activations: dict[str, list[torch.Tensor]] = {} + self.quantized_activations: dict[str, list[torch.Tensor]] = {} + self.input_hashes: list[str] = [] # hashes for "original" phase + self.quant_input_hashes: list[str] = [] # hashes for "quantized" phase + + # Computed after both phases + self.mse_results: dict[str, float] | None = None + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + + @torch.no_grad() + def collect( + self, + model: nn.Module, + data, + phase: str, + target_modules: dict[str, nn.Module] | None = None, + ) -> None: + """Collect per-linear-layer output activations for a given phase. + + Args: + model: The model to run (original or quantized). + data: An iterable of batches. Each batch can be: + + - ``torch.Tensor`` with shape ``[B, seq_len]`` (FP-Quant style). + - ``dict`` with at least an ``"input_ids"`` key (ModelOpt style). 
+ - ``list`` / ``tuple`` of tensors. + phase: ``"original"`` or ``"quantized"``. + target_modules: Optional explicit mapping of ``{name: nn.Module}`` + to attach hooks to. If *None*, layers are auto-discovered + via decoder-block scanning. + """ + if phase not in ("original", "quantized"): + raise ValueError(f"phase must be 'original' or 'quantized', got {phase!r}") + + was_training = model.training + model.eval() + + # ----- layer discovery ----- + targets = ( + target_modules + if target_modules is not None + else (_portable_discover_target_layers(model, self.layer_filter)) + ) + if not targets: + raise ValueError( + "No linear layers found. Provide target_modules explicitly or " + f"check layer_filter={self.layer_filter!r}." + ) + + print( + f"[ActivationMSELogger] Phase '{phase}': hooking {len(targets)} layers, " + f"max_samples={self.max_samples}" + ) + + # ----- storage ----- + saved: dict[str, list[torch.Tensor]] = {name: [] for name in targets} + captured: dict[str, torch.Tensor] = {} + hashes: list[str] = [] + + def _make_hook(key: str): + def hook(_module: nn.Module, _input, output) -> None: + captured[key] = _tensor_from_output(output).cpu() + + return hook + + hooks = [] + for name, module in targets.items(): + hooks.append(module.register_forward_hook(_make_hook(name))) + + try: + n_batches = 0 + for batch in tqdm(data, desc=f"Collecting ({phase})", leave=False): + if self.max_samples is not None and n_batches >= self.max_samples: + break + + captured.clear() + self._run_batch(model, batch) + + for name in targets: + if name in captured: + saved[name].append(captured[name]) + + hashes.append(self._hash_batch(batch)) + n_batches += 1 + finally: + for h in hooks: + h.remove() + + model.train(was_training) + + # ----- store results on self ----- + if phase == "original": + self.original_activations = saved + self.input_hashes = hashes + else: + self.quantized_activations = saved + self.quant_input_hashes = hashes + # Verify sample consistency + if 
self.input_hashes: + self._verify_hashes() + + # Invalidate any previous MSE since we have new activations + self.mse_results = None + + print(f"[ActivationMSELogger] Collected {n_batches} batches for phase '{phase}'") + + def compute_mse(self) -> dict[str, float]: + """Compute per-layer MSE between original and quantized activations. + + Returns: + Dict mapping layer name to its MSE value. + + Raises: + ValueError: If either phase has not been collected yet. + """ + if not self.original_activations: + raise ValueError( + "No original activations collected. Call collect(..., phase='original') first." + ) + if not self.quantized_activations: + raise ValueError( + "No quantized activations collected. Call collect(..., phase='quantized') first." + ) + + common_keys = sorted( + set(self.original_activations.keys()) & set(self.quantized_activations.keys()) + ) + if not common_keys: + raise ValueError( + "No matching layer names between original and quantized activations. " + "Ensure the same model architecture / layer_filter is used for both phases." 
+ ) + + orig_only = set(self.original_activations.keys()) - set(self.quantized_activations.keys()) + quant_only = set(self.quantized_activations.keys()) - set(self.original_activations.keys()) + if orig_only: + print( + f"[ActivationMSELogger] Warning: {len(orig_only)} layers only in original (skipped)" + ) + if quant_only: + print( + f"[ActivationMSELogger] Warning: {len(quant_only)} layers only in quantized (skipped)" + ) + + sum_sq: dict[str, float] = dict.fromkeys(common_keys, 0.0) + count: dict[str, int] = dict.fromkeys(common_keys, 0) + + for name in common_keys: + orig_list = self.original_activations[name] + quant_list = self.quantized_activations[name] + n = min(len(orig_list), len(quant_list)) + for i in range(n): + o = orig_list[i].float() + q = quant_list[i].float() + if o.shape != q.shape: + print( + f"[ActivationMSELogger] Warning: shape mismatch for {name} " + f"batch {i}: {o.shape} vs {q.shape}, skipping" + ) + continue + sum_sq[name] += F.mse_loss(o, q, reduction="sum").item() + count[name] += o.numel() + + self.mse_results = { + key: (sum_sq[key] / count[key]) if count[key] > 0 else float("nan") + for key in common_keys + } + return self.mse_results + + def save(self, path: str | None = None) -> str: + """Save all state (activations, hashes, MSE) to disk via ``torch.save``. + + Args: + path: Explicit file path. If *None*, a timestamped file is created + inside ``self.save_dir`` (which must be set). + + Returns: + The path where the file was saved. 
+ """ + if path is None: + if self.save_dir is None: + raise ValueError("Provide a path or set save_dir in the constructor.") + os.makedirs(self.save_dir, exist_ok=True) + ts = datetime.now().strftime("%Y%m%d_%H%M%S") + path = os.path.join(self.save_dir, f"activation_mse_{ts}.pt") + + payload = { + "max_samples": self.max_samples, + "layer_filter": self.layer_filter, + "input_hashes": self.input_hashes, + "quant_input_hashes": self.quant_input_hashes, + "original_activations": self.original_activations, + "quantized_activations": self.quantized_activations, + "mse": self.mse_results, + } + torch.save(payload, path) + print(f"[ActivationMSELogger] Saved to {path}") + return path + + @classmethod + def load(cls, path: str) -> "ActivationMSELogger": + """Load a previously saved ``ActivationMSELogger`` from disk. + + Args: + path: Path to the ``.pt`` file created by :meth:`save`. + + Returns: + A new ``ActivationMSELogger`` instance with restored state. + """ + payload = torch.load(path, map_location="cpu", weights_only=False) + logger = cls( + max_samples=payload.get("max_samples", 16), + layer_filter=payload.get("layer_filter"), + ) + logger.original_activations = payload.get("original_activations", {}) + logger.quantized_activations = payload.get("quantized_activations", {}) + logger.input_hashes = payload.get("input_hashes", []) + logger.quant_input_hashes = payload.get("quant_input_hashes", []) + logger.mse_results = payload.get("mse") + print(f"[ActivationMSELogger] Loaded from {path}") + return logger + + def summary(self) -> str: + """Return a formatted string summarising per-layer MSE results. + + Computes MSE first if not already done. 
+ """ + if self.mse_results is None: + self.compute_mse() + assert self.mse_results is not None + + lines = ["Per-layer activation MSE (original vs quantized):"] + for key in sorted(self.mse_results.keys()): + lines.append(f" {key}: {self.mse_results[key]:.6e}") + return "\n".join(lines) + + # ------------------------------------------------------------------ + # Pre-materialized MSE data (cross-run / cross-codebase safety) + # ------------------------------------------------------------------ + + @staticmethod + def materialize_data( + data, + path: str, + max_samples: int | None = None, + ) -> list[torch.Tensor]: + """Freeze the first *max_samples* batches from *data* into a ``.pt`` file. + + Each batch (``dict``, ``Tensor``, or ``list/tuple``) is normalised to a + single ``input_ids`` CPU tensor before saving. The resulting file is a + plain ``List[Tensor]`` that can be loaded in **any** codebase and passed + straight to :meth:`collect`. + + If *path* already exists it is **not** overwritten -- call + :meth:`load_data` instead. + + Args: + data: Iterable of batches (DataLoader, List[Tensor], etc.). + path: Destination ``.pt`` file path. + max_samples: How many batches to keep. ``None`` means all. + + Returns: + The materialised list of CPU tensors (same object that was saved). 
+ """ + samples: list[torch.Tensor] = [] + for batch in data: + if max_samples is not None and len(samples) >= max_samples: + break + if isinstance(batch, dict): + t = batch.get("input_ids", next(iter(batch.values()))) + elif isinstance(batch, torch.Tensor): + t = batch + elif isinstance(batch, (list, tuple)): + t = batch[0] + else: + raise TypeError(f"Unsupported batch type: {type(batch)}") + samples.append(t.cpu()) + + os.makedirs(os.path.dirname(path) or ".", exist_ok=True) + torch.save(samples, path) + print(f"[ActivationMSELogger] Materialised {len(samples)} MSE input samples -> {path}") + return samples + + @staticmethod + def load_data(path: str) -> list[torch.Tensor]: + """Load a previously materialised MSE input set. + + Args: + path: Path to the ``.pt`` file created by :meth:`materialize_data`. + + Returns: + ``List[Tensor]`` of input batches (on CPU). + """ + samples = torch.load(path, map_location="cpu", weights_only=True) + print(f"[ActivationMSELogger] Loaded {len(samples)} MSE input samples from {path}") + return samples + + # ------------------------------------------------------------------ + # Raw-text materialization (cross-model / cross-tokenizer reuse) + # ------------------------------------------------------------------ + + @staticmethod + def materialize_raw_text( + data, + path: str, + tokenizer=None, + max_samples: int | None = None, + ) -> list[str]: + """Save raw text strings to a JSON file for cross-model reuse. + + Extracts text from batches by decoding ``input_ids`` with the provided + *tokenizer*. The saved JSON file can be loaded by any model regardless + of its vocabulary and re-tokenized via :meth:`tokenize_raw_text`. + + Args: + data: Iterable of batches (DataLoader, ``List[Tensor]``, etc.). + path: Destination ``.json`` file path. + tokenizer: A HuggingFace tokenizer with a ``decode`` method. + Required to convert token IDs back to text. + max_samples: How many batches to keep. ``None`` means all. 
+ + Returns: + The list of decoded text strings (same content that was saved). + """ + if tokenizer is None: + raise ValueError( + "tokenizer is required for materialize_raw_text to decode input_ids back to text." + ) + + texts: list[str] = [] + for batch in data: + if max_samples is not None and len(texts) >= max_samples: + break + if isinstance(batch, dict): + t = batch.get("input_ids", next(iter(batch.values()))) + elif isinstance(batch, torch.Tensor): + t = batch + elif isinstance(batch, (list, tuple)): + t = batch[0] + else: + raise TypeError(f"Unsupported batch type: {type(batch)}") + + if t.dim() == 1: + t = t.unsqueeze(0) + for row in t: + if max_samples is not None and len(texts) >= max_samples: + break + texts.append(tokenizer.decode(row, skip_special_tokens=True)) + + os.makedirs(os.path.dirname(path) or ".", exist_ok=True) + payload = {"texts": texts, "max_samples": len(texts)} + with open(path, "w", encoding="utf-8") as f: + json.dump(payload, f, ensure_ascii=False, indent=2) + + print(f"[ActivationMSELogger] Saved {len(texts)} raw text samples -> {path}") + return texts + + @staticmethod + def load_raw_text(path: str) -> list[str]: + """Load raw text strings from a JSON file created by :meth:`materialize_raw_text`. + + Args: + path: Path to the ``.json`` file. + + Returns: + List of raw text strings. + """ + with open(path, encoding="utf-8") as f: + payload = json.load(f) + texts = payload["texts"] + print(f"[ActivationMSELogger] Loaded {len(texts)} raw text samples from {path}") + return texts + + @staticmethod + def tokenize_raw_text( + texts: list[str], + tokenizer, + max_length: int = 2048, + ) -> list[torch.Tensor]: + """Tokenize raw text strings into a ``List[Tensor]`` for :meth:`collect`. + + Each string is independently tokenized and truncated to *max_length*. + Returns one ``[1, seq_len]`` tensor per string — the same format + expected by :meth:`collect` and :func:`compute_perplexity`. 
+ + Args: + texts: List of raw text strings (from :meth:`load_raw_text`). + tokenizer: A HuggingFace tokenizer. + max_length: Maximum token length per sample (default: 2048). + + Returns: + ``List[Tensor]`` of tokenized inputs on CPU. + """ + samples: list[torch.Tensor] = [] + for text in texts: + encoded = tokenizer( + text, + return_tensors="pt", + max_length=max_length, + truncation=True, + add_special_tokens=False, + ) + samples.append(encoded.input_ids.cpu()) + print(f"[ActivationMSELogger] Tokenized {len(samples)} samples (max_length={max_length})") + return samples + + # ------------------------------------------------------------------ + # Static / private helpers + # ------------------------------------------------------------------ + + @staticmethod + def _run_batch(model: nn.Module, batch) -> None: + """Run a single batch through the model (handles Tensor, dict, list/tuple). + + Automatically moves inputs to the model's device so that CPU-stored + materialized data works transparently with a CUDA model. + """ + device = next(model.parameters()).device + if isinstance(batch, dict): + batch = { + k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items() + } + model(**batch) + elif isinstance(batch, torch.Tensor): + model(batch.to(device)) + elif isinstance(batch, (list, tuple)): + batch = tuple(t.to(device) if isinstance(t, torch.Tensor) else t for t in batch) + model(*batch) + else: + raise TypeError(f"Unsupported batch type: {type(batch)}") + + @staticmethod + def _hash_batch(batch) -> str: + """Compute SHA-256 hash of the primary input tensor in *batch*. + + - ``dict`` -> hashes ``batch["input_ids"]`` (falls back to first value). + - ``Tensor`` -> hashes the tensor directly. + - ``list/tuple`` -> hashes the first element. 
+ """ + if isinstance(batch, dict): + t = batch.get("input_ids", next(iter(batch.values()))) + elif isinstance(batch, torch.Tensor): + t = batch + elif isinstance(batch, (list, tuple)): + t = batch[0] if batch else None + else: + return "" + + if t is None or not isinstance(t, torch.Tensor): + return "" + return hashlib.sha256(t.cpu().contiguous().numpy().tobytes()).hexdigest() + + def _verify_hashes(self) -> None: + """Compare input hashes between original and quantized phases.""" + n = min(len(self.input_hashes), len(self.quant_input_hashes)) + mismatches = sum(1 for i in range(n) if self.input_hashes[i] != self.quant_input_hashes[i]) + if mismatches: + print( + f"[ActivationMSELogger] WARNING: {mismatches}/{n} batches have " + f"different input hashes between original and quantized phases. " + f"The same data may not have been used for both phases!" + ) + else: + print(f"[ActivationMSELogger] Input hash verification passed ({n}/{n} match)") diff --git a/modelopt/torch/quantization/metrics/perplexity.py b/modelopt/torch/quantization/metrics/perplexity.py new file mode 100644 index 0000000000..2b592914ae --- /dev/null +++ b/modelopt/torch/quantization/metrics/perplexity.py @@ -0,0 +1,81 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+# mypy: ignore-errors
+# ruff: noqa: D103, PERF401
+
+"""Perplexity evaluation for language models.
+
+Ported from FP-Quant: https://github.com/IST-DASLab/FP-Quant
+"""
+
+import torch
+import torch.nn.functional as F
+from tqdm import trange
+
+
+@torch.no_grad()
+def compute_perplexity(model, data, batch_size: int = 1):
+    num_samples = len(data)
+    device = next(model.parameters()).device
+    # Running estimate of negative log-likelihood
+    nll_running = 0
+    # Number of tokens processed so far
+    tokens_processed = 0
+    # Loop through each batch
+    for i in trange(0, num_samples, batch_size, desc="Computing perplexity", leave=False):
+        j = min(i + batch_size, num_samples)
+        inputs = torch.cat(data[i:j]).to(device)
+        # Forward pass through the model
+        lm_logits = model(inputs).logits
+        # Shift logits and labels for next token prediction
+        shift_logits = lm_logits[:, :-1, :].contiguous()
+        shift_labels = inputs[:, 1:]
+        # Compute loss
+        loss = F.cross_entropy(
+            shift_logits.reshape(-1, shift_logits.size(-1)), shift_labels.reshape(-1)
+        )
+        # Calculate negative log likelihood
+        a = shift_labels.numel() / (tokens_processed + shift_labels.numel())
+        b = tokens_processed / (tokens_processed + shift_labels.numel())
+        nll_running = a * loss + b * nll_running
+        # Update number of processed tokens
+        tokens_processed += shift_labels.numel()
+    # Compute perplexity
+    ppl = nll_running.exp().item()
+    return ppl
+
+
+def get_wikitext2(tokenizer, sequence_length: int):
+    """Load WikiText-2 test set as a list of tokenized sequences for perplexity evaluation.
+
+    Args:
+        tokenizer: HuggingFace tokenizer.
+        sequence_length: Length of each evaluation sequence.
+
+    Returns:
+        List of tensors, each of shape ``[1, sequence_length]``. 
+ """ + from datasets import load_dataset + + test_dataset_raw = load_dataset("wikitext", "wikitext-2-raw-v1", split="test") + test_dataset_tok = tokenizer( + "\n\n".join(test_dataset_raw["text"]), return_tensors="pt" + ).input_ids + num_test_sequences = test_dataset_tok.numel() // sequence_length + test_loader = [] + for i in range(num_test_sequences): + test_loader.append(test_dataset_tok[:, i * sequence_length : (i + 1) * sequence_length]) + return test_loader diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 1e8a94b3db..7589ea703a 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -2106,8 +2106,12 @@ def gptq( def _make_hessian_forward(module_name): def hessian_forward(self, input, *args, **kwargs): inp = input.to_local() if hasattr(input, "to_local") else input + if self.input_quantizer is not None and self.input_quantizer.is_enabled: + hessian_input = self.input_quantizer(inp) + else: + hessian_input = inp state = hessian_state[module_name] - hessian, n_samples = update_hessian(inp, state["hessian"], state["n_samples"]) + hessian, n_samples = update_hessian(hessian_input, state["hessian"], state["n_samples"]) hessian_state[module_name] = {"hessian": hessian, "n_samples": n_samples} self.weight_quantizer.disable() From ee40b4841a63831f57667278fd8a63492cebced7 Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Sat, 21 Mar 2026 00:12:21 +0000 Subject: [PATCH 33/48] claude review Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- examples/llm_ptq/example_utils.py | 3 -- modelopt/torch/quantization/config.py | 20 ++++--------- modelopt/torch/quantization/model_calib.py | 21 ++++++++++++-- tests/gpu/torch/quantization/test_gptq.py | 33 +++++++++++++--------- 4 files changed, 43 insertions(+), 34 deletions(-) diff --git a/examples/llm_ptq/example_utils.py 
b/examples/llm_ptq/example_utils.py index c7e00ebc65..58eb676111 100755 --- a/examples/llm_ptq/example_utils.py +++ b/examples/llm_ptq/example_utils.py @@ -551,9 +551,6 @@ def get_model( try: hf_config = AutoConfig.from_pretrained(ckpt_path, **config_kwargs) - if not hasattr(hf_config, "moe_latent_size"): - hf_config.moe_latent_size = None - if is_nemotron_vl(hf_config): print( "Detected Nemotron VL model from config. " diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index 8c1ee07065..c0bfe62508 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -1383,19 +1383,15 @@ class GPTQLiteConfig(QuantizeAlgorithmConfig): class GPTQConfig(QuantizeAlgorithmConfig): - """The config for GPTQ lite. - - GPTQ lite is a variant of GPTQ that does not exactly follow the official GPTQ implementation. + """The config for GPTQ quantization. - GPTQ lite does not perform sequential quantization of layers. This means that the updated - activations are not used to process the next layer. + GPTQ minimizes the layer-wise quantization error by using second-order (Hessian) information + to perform blockwise weight updates that compensate for rounding loss. Layers are quantized + sequentially so that each layer's Hessian is computed from activations that already reflect + the quantization of preceding layers. The default values are taken from the official GPTQ implementation: https://github.com/IST-DASLab/FP-Quant/blob/d2e3092f968262c4de5fb050e1aef568a280dadd/src/quantization/gptq.py#L35 - - Note: This feature is currently experimental and may not translate to improved accuracy as expected. 
- - """ method: Literal["gptq"] = ModeloptField("gptq") @@ -1412,12 +1408,6 @@ class GPTQConfig(QuantizeAlgorithmConfig): description="""The block size for GPTQ weight update, which must be a multiple of the group_size used in the quantization.""", ) - hessian_state_path: str | None = ModeloptField( - default=None, - title="Path to the Hessian state file.", - description="""The path to the Hessian state file. If hessian path exists, we load from - hessian file instead of recomputing them.""", - ) QuantizeQuantCfgType = dict[ diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 7589ea703a..4c7ecf86ec 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -17,6 +17,7 @@ import math import os +import time import warnings from collections.abc import Callable from functools import partial @@ -1799,6 +1800,8 @@ def _blockwise_weight_update_unfused( n_cols_blk = block_end - block_start h_inv_cho_blk = h_inv[block_start:block_end, block_start:block_end] + # wblk is a scratch copy for intra-block error propagation; weight gets + # the final quantized values. Inter-block errors are propagated via addmm_ below. if col_qdq_supported: wblk = weight[:, block_start:block_end].clone() errs = torch.zeros_like(wblk) @@ -2059,7 +2062,21 @@ def gptq( block_size: int = 128, **kwargs, ): - """GPTQ quantization - a GPTQ variant. + """GPTQ quantization for a single decoder layer. + + Invoked by ``sequential_calibrate`` which walks layers one at a time so each + layer sees activations already updated by the quantization of preceding layers. + Within a layer the steps are: + + 1. ``max_calibrate`` to set amax values from the current activations. + 2. Promote eligible quantizers to ``NVFP4StaticQuantizer`` (two-level scaling). + 3. Collect per-linear-layer Hessian matrices via forward hooks. + 4. 
Blockwise weight updates using the inverse Hessian to compensate for + rounding error (the core GPTQ column-wise update). + + In contrast to ``gptq_lite``, which quantizes all layers in parallel using the + original (unquantized) activations, this method performs sequential calibration + and therefore produces more accurate Hessian estimates. Args: layer: A single decoder layer to quantize. @@ -2068,8 +2085,6 @@ def gptq( percdamp: Percentage of avg Hessian diagonal for damping (default: 0.01). block_size: Block size for GPTQ weight update. """ - import time - total_start = time.time() # Set weight amax and activation amax for the current layer using max_calibrate diff --git a/tests/gpu/torch/quantization/test_gptq.py b/tests/gpu/torch/quantization/test_gptq.py index 23bdf6cbff..7f8db20446 100644 --- a/tests/gpu/torch/quantization/test_gptq.py +++ b/tests/gpu/torch/quantization/test_gptq.py @@ -55,8 +55,11 @@ def test_update_hessian(): f"Expected hessian shape ({features}, {features}), got {updated_hessian.shape}" ) - # Verify sample count is updated correctly (incremented by batch_size) - assert new_n_samples == batch_size, f"Expected n_samples={batch_size}, got {new_n_samples}" + # Verify sample count is updated correctly (incremented by total tokens = batch * seq_len) + expected_n_samples = batch_size * seq_len + assert new_n_samples == expected_n_samples, ( + f"Expected n_samples={expected_n_samples}, got {new_n_samples}" + ) # Verify hessian is not all zeros after update assert not torch.allclose(updated_hessian, torch.zeros_like(updated_hessian)), ( @@ -79,22 +82,23 @@ def test_update_hessian(): # Manual calculation: # input_flat shape: (features, batch*seq) = (2, 12), all ones - # scaled_input = sqrt(2/6) * input_flat = sqrt(1/3) * ones(2, 12) - # outer_product = scaled_input @ scaled_input.t() = (2/6) * ones(2,12) @ ones(12,2) = [[4,4], [4,4]] - # Note: The scaling factor is (2/n_samples), so with n_samples=6 and 12 tokens: (2/6)*12 = 4 - expected_hessian = 
torch.ones(features, features, dtype=torch.float32) * 4.0 + # n_samples = batch * seq = 12 (token count after flattening) + # scaled_input = sqrt(2/12) * ones(2, 12) + # outer_product = (2/12) * ones(2,12) @ ones(12,2) = [[2,2], [2,2]] + expected_n_samples = batch_size * seq_len # 12 tokens + expected_hessian = torch.ones(features, features, dtype=torch.float32) * 2.0 assert torch.allclose(updated_hessian, expected_hessian, atol=1e-5), ( f"Expected hessian {expected_hessian}, got {updated_hessian}" ) - assert new_n_samples == batch_size + assert new_n_samples == expected_n_samples # Test 3: Accumulated hessians - verify equivalence # Processing [6,2,2] in one step should equal processing [2,2,2] three times seq_len = 2 features = 2 - # Process in 3 steps of batch_size=2 + # Process in 3 steps of batch_size=2 (4 tokens each, 12 total) hessian_accumulated = torch.zeros(features, features, dtype=torch.float32) n_samples_accumulated = 0 @@ -111,7 +115,8 @@ def test_update_hessian(): assert torch.allclose(hessian_accumulated, expected_hessian, atol=1e-5), ( f"Accumulated hessian should match expected: expected {expected_hessian}, got {hessian_accumulated}" ) - assert n_samples_accumulated == 6, f"Expected n_samples=6, got {n_samples_accumulated}" + # 3 batches * 2 batch_size * 2 seq_len = 12 tokens + assert n_samples_accumulated == 12, f"Expected n_samples=12, got {n_samples_accumulated}" @pytest.mark.parametrize( @@ -146,14 +151,16 @@ def test_gptq_updates(block_size, dim, model_weight, expect_weight_change): hessian, n_samples = update_hessian(input, hessian, n_samples) - # Verify n_samples is update using hessian matrix - assert n_samples == input.shape[0], "n_samples should be equal to input.shape[0]" + # Verify n_samples counts total tokens (batch * seq_len) after flattening + expected_tokens = input.shape[0] * input.shape[1] # 2 * 16 = 32 + assert n_samples == expected_tokens, f"n_samples should be {expected_tokens}, got {n_samples}" # Perform another forward pass 
to update hessian matrix input_2 = torch.randn(3, 16, dim).to("cuda") hessian, n_samples = update_hessian(input_2, hessian, n_samples) - assert n_samples == input.shape[0] + input_2.shape[0], ( - "n_samples should be equal to input.shape[0] + input_2.shape[0]" + expected_tokens_2 = expected_tokens + input_2.shape[0] * input_2.shape[1] # 32 + 48 = 80 + assert n_samples == expected_tokens_2, ( + f"n_samples should be {expected_tokens_2}, got {n_samples}" ) hessian = hessian.to(input.device) From a1751783de4aa8ff171d0582f75e8a71c5f909cf Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Sat, 21 Mar 2026 00:40:02 +0000 Subject: [PATCH 34/48] remove stray files Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- examples/llm_ptq/hf_ptq.py | 130 --- modelopt/torch/quantization/__init__.py | 8 +- .../torch/quantization/metrics/__init__.py | 28 - .../quantization/metrics/activation_mse.py | 831 ------------------ .../torch/quantization/metrics/perplexity.py | 81 -- 5 files changed, 1 insertion(+), 1077 deletions(-) delete mode 100644 modelopt/torch/quantization/metrics/__init__.py delete mode 100644 modelopt/torch/quantization/metrics/activation_mse.py delete mode 100644 modelopt/torch/quantization/metrics/perplexity.py diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index 8297a977b5..bfe2c861c5 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -15,7 +15,6 @@ import argparse import copy -import os import random import time import warnings @@ -62,11 +61,6 @@ ) from modelopt.torch.export.model_utils import get_language_model_from_vl, is_multimodal_model from modelopt.torch.quantization.config import _default_disabled_quantizer_cfg, need_calibration -from modelopt.torch.quantization.metrics import ( - ActivationMSELogger, - compute_perplexity, - get_wikitext2, -) from modelopt.torch.quantization.plugins.accelerate import init_quantized_weights from 
modelopt.torch.quantization.utils import is_quantized from modelopt.torch.utils.dataset_utils import ( @@ -103,7 +97,6 @@ def _set_kv_cache_constant_amax(quant_cfg: dict) -> None: "int4_awq": mtq.INT4_AWQ_CFG, "w4a8_awq": mtq.W4A8_AWQ_BETA_CFG, "nvfp4": mtq.NVFP4_DEFAULT_CFG, - "nvfp4_wo": mtq.NVFP4_WEIGHT_ONLY_CFG, "nvfp4_awq": mtq.NVFP4_AWQ_LITE_CFG, "nvfp4_mse": mtq.NVFP4_W4A4_WEIGHT_MSE_FP8_SWEEP_CFG, "fp8_pb_wo": mtq.FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG, @@ -114,7 +107,6 @@ def _set_kv_cache_constant_amax(quant_cfg: dict) -> None: "nvfp4_experts_only": mtq.NVFP4_EXPERTS_ONLY_CFG, "nvfp4_omlp_only": mtq.NVFP4_OMLP_ONLY_CFG, "nvfp4_svdquant": mtq.NVFP4_SVDQUANT_DEFAULT_CFG, - "nvfp4_wo_gptq": mtq.NVFP4_WEIGHT_ONLY_GPTQ_CFG, "nvfp4_gptq": mtq.NVFP4_GPTQ_CFG, "mxfp8": mtq.MXFP8_DEFAULT_CFG, "nvfp4_local_hessian": mtq.NVFP4_W4A4_WEIGHT_LOCAL_HESSIAN_CFG, @@ -688,12 +680,6 @@ def export_quantized( "They will be set at deployment time." ) - if getattr(args, "eval_perplexity", False) and tokenizer is not None: - seq_len = getattr(args, "eval_perplexity_seq_len", 2048) - eval_data = get_wikitext2(tokenizer, seq_len) - ppl = compute_perplexity(full_model, eval_data) - print(f"Wikitext-2 perplexity: {round(ppl, 2):.2f}") - # Load any missing weights from non-standard safetensors (handled in get_model for non-low-memory mode) # Store the MTP layer prefixes on the model for later exclusion from quantization mtp_layer_prefixes, mtp_state_dict = load_mtp_weights(full_model, args.pyt_ckpt_path) @@ -925,64 +911,6 @@ def quantize_main( args, full_model, model_type, tokenizer, calib_dataloader, is_nemotron_vl_model ) - # Collect original (unquantized) activations before quantization modifies the model - mse_logger = None - if getattr(args, "measure_activation_mse", False): - n_mse = getattr(args, "activation_mse_max_samples", 16) - mse_save_dir = getattr(args, "activation_mse_save_dir", None) - mse_input_path = getattr(args, "activation_mse_input_path", None) - - # Resolve MSE input 
data: frozen file (raw text or tokenized) or live dataloader - mse_data = None - if mse_input_path is not None: - if mse_input_path.endswith(".json"): - if os.path.isfile(mse_input_path): - print(f"Loading MSE input data from existing .json file: {mse_input_path}") - texts = ActivationMSELogger.load_raw_text(mse_input_path) - mse_data = ActivationMSELogger.tokenize_raw_text( - texts, - tokenizer, - max_length=args.calib_seq, - ) - else: - assert tokenizer is not None, ( - "--activation_mse_input_path with .json requires a tokenizer to decode" - ) - print(f"Creating MSE input data .json file: {mse_input_path}") - texts = ActivationMSELogger.materialize_raw_text( - calib_dataloader, - mse_input_path, - tokenizer=tokenizer, - max_samples=n_mse, - ) - mse_data = ActivationMSELogger.tokenize_raw_text( - texts, - tokenizer, - max_length=args.calib_seq, - ) - elif mse_input_path.endswith(".pt"): - if os.path.isfile(mse_input_path): - print(f"Loading MSE input data from existing .pt file: {mse_input_path}") - mse_data = ActivationMSELogger.load_data(mse_input_path) - else: - print(f"Creating MSE input data .pt file: {mse_input_path}") - mse_data = ActivationMSELogger.materialize_data( - calib_dataloader, - mse_input_path, - max_samples=n_mse, - ) - else: - raise ValueError( - f"--activation_mse_input_path must end with .json or .pt, got: {mse_input_path}" - ) - - if mse_data is None: - mse_data = calib_dataloader - - mse_logger = ActivationMSELogger(max_samples=n_mse, save_dir=mse_save_dir) - print(f"Collecting original (unquantized) activations for MSE over {n_mse} samples...") - mse_logger.collect(language_model, mse_data, phase="original") - if args.auto_quantize_bits: assert len(args.qformat.split(",")) > 1, ( "Auto quantization needs multiple quantization format." 
@@ -1076,22 +1004,6 @@ def quantize_main( first_text_speech_dataset, ) - if mse_logger is not None: - import gc - - print("Collecting quantized activations for MSE...") - mse_logger.collect(language_model, mse_data, phase="quantized") - - mse_logger.compute_mse() - print(mse_logger.summary()) - - if getattr(args, "activation_mse_save_dir", None): - mse_logger.save() - - del mse_logger, mse_data - gc.collect() - torch.cuda.empty_cache() - export_quantized( args, full_model, @@ -1309,48 +1221,6 @@ def parse_args() -> argparse.Namespace: "Does not impact non-MOE models." ), ) - parser.add_argument( - "--eval_perplexity", - action=argparse.BooleanOptionalAction, - default=False, - help="Evaluate Wikitext-2 perplexity after quantization (before export).", - ) - parser.add_argument( - "--eval_perplexity_seq_len", - type=int, - default=2048, - help="Sequence length for perplexity evaluation (default: 2048).", - ) - parser.add_argument( - "--measure_activation_mse", - action=argparse.BooleanOptionalAction, - default=False, - help="Measure per-layer activation MSE (original vs quantized) after quantization.", - ) - parser.add_argument( - "--activation_mse_max_samples", - type=int, - default=16, - help="Max calibration samples for activation MSE (default: 16).", - ) - parser.add_argument( - "--activation_mse_save_dir", - type=str, - default=None, - help="Directory to save activation MSE results. If not set, results are only printed.", - ) - parser.add_argument( - "--activation_mse_input_path", - type=str, - default=None, - help=( - "Path to frozen MSE input data. Supports two formats:\n" - " .json — raw text (cross-model reuse): if file exists, loads and re-tokenizes " - "with the current model's tokenizer; if not, decodes calibration data to text and saves.\n" - " .pt — tokenized tensors (same-tokenizer reuse): if file exists, loads directly; " - "if not, materializes from calibration data and saves." 
- ), - ) args = parser.parse_args() if args.moe_calib_experts_ratio is not None and not (0.0 < args.moe_calib_experts_ratio <= 1.0): diff --git a/modelopt/torch/quantization/__init__.py b/modelopt/torch/quantization/__init__.py index d471e55823..87dbf30bb5 100644 --- a/modelopt/torch/quantization/__init__.py +++ b/modelopt/torch/quantization/__init__.py @@ -16,18 +16,12 @@ """Quantization package.""" # Initialize mode and plugins -from . import metrics, mode, plugins, utils +from . import mode, plugins, utils # Add methods to mtq namespace from .compress import * from .config import * from .conversion import * -from .metrics import ( - ActivationMSELogger, - compute_perplexity, - get_wikitext2, - measure_per_layer_activation_mse, -) from .model_quant import * from .nn.modules.quant_module import QuantModuleRegistry from .utils import update_quant_cfg_with_kv_cache_quant diff --git a/modelopt/torch/quantization/metrics/__init__.py b/modelopt/torch/quantization/metrics/__init__.py deleted file mode 100644 index a1c737c3c0..0000000000 --- a/modelopt/torch/quantization/metrics/__init__.py +++ /dev/null @@ -1,28 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -# mypy: ignore-errors - -"""Metrics for evaluating quantized models.""" - -from .activation_mse import ActivationMSELogger, measure_per_layer_activation_mse -from .perplexity import compute_perplexity, get_wikitext2 - -__all__ = [ - "ActivationMSELogger", - "compute_perplexity", - "get_wikitext2", - "measure_per_layer_activation_mse", -] diff --git a/modelopt/torch/quantization/metrics/activation_mse.py b/modelopt/torch/quantization/metrics/activation_mse.py deleted file mode 100644 index 1b60977ee1..0000000000 --- a/modelopt/torch/quantization/metrics/activation_mse.py +++ /dev/null @@ -1,831 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# mypy: ignore-errors -# ruff: noqa: D107, D205, PERF401, PLR0124 - -"""Per-layer activation MSE between original (unquantized) and quantized model. - -Includes the portable ``ActivationMSELogger`` class that works across codebases -(FP-Quant List[Tensor] style *and* ModelOpt DataLoader-of-dicts style). - -Ported from FP-Quant: https://github.com/IST-DASLab/FP-Quant -""" - -import fnmatch -import gc -import hashlib -import json -import os -from datetime import datetime - -import torch -import torch.nn as nn -import torch.nn.functional as F -from tqdm import tqdm - - -def _get_module(block: nn.Module, name: str) -> nn.Module: - """Get submodule from block by dotted name, e.g. 
'self_attn.q_proj'.""" - obj = block - for part in name.split("."): - obj = getattr(obj, part) - return obj - - -def _get_linear_layer_names(block: nn.Module) -> list[str]: - """Collect relative names of linear layers in a transformer block (same as GPTQ).""" - names = [] - for name, layer in block.named_modules(): - if isinstance(layer, nn.Linear): - names.append(name) - return names - - -def _tensor_from_output(out) -> torch.Tensor: - """Extract a single tensor from layer output (handle tuple return).""" - if isinstance(out, torch.Tensor): - return out.detach() - return out[0].detach() - - -def _discover_layer_keys(blocks, layer_names, num_blocks): - """Build list of valid layer keys.""" - keys = [] - for i in range(num_blocks): - for name in layer_names: - try: - _get_module(blocks[i], name) - except AttributeError: - continue - keys.append(f"model.layers.{i}.{name}") - return keys - - -def _collect_outputs( - model: nn.Module, - blocks: nn.ModuleList, - layer_names: list[str], - layer_keys: list[str], - calibration_data: list[torch.Tensor], - device: torch.device | str, - num_blocks: int, - desc: str, -) -> dict[str, list[torch.Tensor]]: - """Run model on calibration data, capture per-layer outputs (moved to CPU).""" - captured: dict[str, torch.Tensor] = {} - saved: dict[str, list[torch.Tensor]] = {k: [] for k in layer_keys} - - def make_hook(key: str): - def hook(_module: nn.Module, _input: tuple, output) -> None: - captured[key] = _tensor_from_output(output).cpu() - - return hook - - hooks = [] - for i in range(num_blocks): - for name in layer_names: - key = f"model.layers.{i}.{name}" - if key not in saved: - continue - try: - mod = _get_module(blocks[i], name) - except AttributeError: - continue - hooks.append(mod.register_forward_hook(make_hook(key))) - - try: - for sample in tqdm(calibration_data, desc=desc, leave=False): - inp = sample.unsqueeze(0) if sample.dim() == 1 else sample - inp = inp.to(device) - captured.clear() - with torch.no_grad(): - _ = 
model(inp) - for key in layer_keys: - if key in captured: - saved[key].append(captured[key]) - finally: - for h in hooks: - h.remove() - return saved - - -@torch.no_grad() -def measure_per_layer_activation_mse( - model_orig: nn.Module, - model_quant: nn.Module, - calibration_data: list[torch.Tensor], - device: torch.device | str, - log_wandb: bool = False, - max_samples: int | None = None, -) -> dict[str, float]: - """Measure per-linear-layer MSE between activations of the original (unquantized) - model and the quantized model on the same calibration data. - - Runs each model on GPU one at a time to avoid OOM. - Returns a dict mapping layer key (e.g. "model.layers.0.self_attn.q_proj") to MSE. - """ - if max_samples is not None and max_samples > 0: - calibration_data = calibration_data[:max_samples] - - blocks_quant = model_quant.model.layers - blocks_orig = model_orig.model.layers - num_blocks = len(blocks_quant) - assert len(blocks_orig) == num_blocks - - layer_names = _get_linear_layer_names(blocks_quant[0]) - layer_keys = _discover_layer_keys(blocks_quant, layer_names, num_blocks) - - # --- Phase 1: run quantized model on GPU, save outputs to CPU --- - print(" Phase 1/2: collecting quantized model outputs...") - model_quant.to(device) - quant_outputs = _collect_outputs( - model_quant, - blocks_quant, - layer_names, - layer_keys, - calibration_data, - device, - num_blocks, - desc="Activation MSE (quant)", - ) - # Free GPU for original model - model_quant.cpu() - gc.collect() - torch.cuda.empty_cache() - - # --- Phase 2: run original model on GPU, compute MSE vs stored quant --- - print(" Phase 2/2: collecting original model outputs and computing MSE...") - model_orig.to(device) - - # Instead of storing orig outputs, compute MSE on the fly per sample - sum_sq: dict[str, float] = dict.fromkeys(layer_keys, 0.0) - count: dict[str, int] = dict.fromkeys(layer_keys, 0) - - captured: dict[str, torch.Tensor] = {} - - def make_hook(key: str): - def hook(_module: nn.Module, 
_input: tuple, output) -> None: - captured[key] = _tensor_from_output(output).cpu() - - return hook - - hooks = [] - for i in range(num_blocks): - for name in layer_names: - key = f"model.layers.{i}.{name}" - if key not in sum_sq: - continue - try: - mod = _get_module(blocks_orig[i], name) - except AttributeError: - continue - hooks.append(mod.register_forward_hook(make_hook(key))) - - try: - for sample_idx, sample in enumerate( - tqdm(calibration_data, desc="Activation MSE (orig)", leave=False) - ): - inp = sample.unsqueeze(0) if sample.dim() == 1 else sample - inp = inp.to(device) - captured.clear() - _ = model_orig(inp) - for key in layer_keys: - if key not in captured: - continue - if sample_idx >= len(quant_outputs.get(key, [])): - continue - o = captured[key].float() - q = quant_outputs[key][sample_idx].float() - if o.shape != q.shape: - continue - sum_sq[key] += F.mse_loss(o, q, reduction="sum").item() - count[key] += o.numel() - finally: - for h in hooks: - h.remove() - - # Free original model from GPU - model_orig.cpu() - gc.collect() - torch.cuda.empty_cache() - - # Move quantized model back to GPU for downstream usage - model_quant.to(device) - - mse = { - key: (sum_sq[key] / count[key]) if count[key] > 0 else float("nan") for key in layer_keys - } - - if log_wandb: - try: - import wandb - - for key, val in mse.items(): - if val == val: # skip nan - wandb.log({f"activation_mse/{key}": val}) - except ImportError: - pass - - return mse - - -# --------------------------------------------------------------------------- -# Portable ActivationMSELogger class -# --------------------------------------------------------------------------- - - -def _matches_filter(name: str, layer_filter: str | None) -> bool: - """Check if a layer name matches the optional filter pattern (fnmatch-style).""" - if layer_filter is None: - return True - return fnmatch.fnmatch(name, layer_filter) - - -def _portable_discover_target_layers( - model: nn.Module, - layer_filter: str | None 
= None, -) -> dict[str, nn.Module]: - """Discover linear layers in decoder blocks with a portable fallback chain. - - Strategy: - 1. Try modelopt's ``get_decoder_layers`` (available inside ModelOpt). - 2. Try common HuggingFace attribute paths (``model.model.layers``, etc.). - 3. Fall back to scanning **all** ``nn.Linear`` in ``model.named_modules()``. - - Within each set of decoder blocks the function collects every ``nn.Linear`` - sub-module and optionally filters by *layer_filter* (fnmatch pattern). - """ - decoder_layers = None - - # 1. Try modelopt helper - try: - from modelopt.torch.quantization.utils.activation_collector import LayerActivationCollector - - decoder_layers = LayerActivationCollector.get_decoder_layers(model) - except Exception: - pass - - # 2. Try common HF / other patterns - if decoder_layers is None: - for attr_chain in ( - ("model", "layers"), - ("decoder", "layers"), - ("transformer", "h"), - ("backbone", "layers"), - ): - obj = model - try: - for attr in attr_chain: - obj = getattr(obj, attr) - if isinstance(obj, nn.ModuleList): - decoder_layers = obj - break - except AttributeError: - continue - - targets: dict[str, nn.Module] = {} - - if decoder_layers is not None: - module_to_name: dict[int, str] = {id(m): n for n, m in model.named_modules()} - for block in decoder_layers: - block_name = module_to_name.get(id(block), "") - for sub_name, sub_mod in block.named_modules(): - if isinstance(sub_mod, nn.Linear): - full_name = f"{block_name}.{sub_name}" if block_name else sub_name - if _matches_filter(full_name, layer_filter): - targets[full_name] = sub_mod - else: - # 3. Fallback: all linear layers - for name, module in model.named_modules(): - if isinstance(module, nn.Linear): - if _matches_filter(name, layer_filter): - targets[name] = module - - return targets - - -class ActivationMSELogger: - """Portable activation MSE logger for comparing original vs quantized models. 
- - Works with both: - - - ``List[Tensor]`` data (**FP-Quant** style): each element is ``[1, seq_len]`` - or ``[B, seq_len]``, consumed via ``model(tensor)``. - - ``DataLoader`` / ``Iterable`` yielding dicts (**ModelOpt** style): - ``{"input_ids": tensor, ...}``, consumed via ``model(**batch)``. - - Guarantees same samples are used for both phases via SHA-256 hashing of - input tensors. Supports saving / loading all activations to disk for - later cross-codebase comparison. - - Example (FP-Quant -- List[Tensor]):: - - mse_logger = ActivationMSELogger(max_samples=16, save_dir="./mse_logs") - mse_logger.collect(model_orig, calibration_data, phase="original") - mse_logger.collect(model_quant, calibration_data, phase="quantized") - results = mse_logger.compute_mse() - print(mse_logger.summary()) - mse_logger.save() - - Example (ModelOpt -- DataLoader with dict batches):: - - mse_logger = ActivationMSELogger(max_samples=16, save_dir="./mse_logs") - mse_logger.collect(model, dataloader, phase="original") - model = mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop) - mse_logger.collect(model, dataloader, phase="quantized") - results = mse_logger.compute_mse() - print(mse_logger.summary()) - mse_logger.save() - """ - - def __init__( - self, - max_samples: int = 16, - layer_filter: str | None = None, - save_dir: str | None = None, - ): - self.max_samples = max_samples - self.layer_filter = layer_filter - self.save_dir = save_dir - - # Per-phase state - self.original_activations: dict[str, list[torch.Tensor]] = {} - self.quantized_activations: dict[str, list[torch.Tensor]] = {} - self.input_hashes: list[str] = [] # hashes for "original" phase - self.quant_input_hashes: list[str] = [] # hashes for "quantized" phase - - # Computed after both phases - self.mse_results: dict[str, float] | None = None - - # ------------------------------------------------------------------ - # Public API - # ------------------------------------------------------------------ - - 
@torch.no_grad() - def collect( - self, - model: nn.Module, - data, - phase: str, - target_modules: dict[str, nn.Module] | None = None, - ) -> None: - """Collect per-linear-layer output activations for a given phase. - - Args: - model: The model to run (original or quantized). - data: An iterable of batches. Each batch can be: - - - ``torch.Tensor`` with shape ``[B, seq_len]`` (FP-Quant style). - - ``dict`` with at least an ``"input_ids"`` key (ModelOpt style). - - ``list`` / ``tuple`` of tensors. - phase: ``"original"`` or ``"quantized"``. - target_modules: Optional explicit mapping of ``{name: nn.Module}`` - to attach hooks to. If *None*, layers are auto-discovered - via decoder-block scanning. - """ - if phase not in ("original", "quantized"): - raise ValueError(f"phase must be 'original' or 'quantized', got {phase!r}") - - was_training = model.training - model.eval() - - # ----- layer discovery ----- - targets = ( - target_modules - if target_modules is not None - else (_portable_discover_target_layers(model, self.layer_filter)) - ) - if not targets: - raise ValueError( - "No linear layers found. Provide target_modules explicitly or " - f"check layer_filter={self.layer_filter!r}." 
- ) - - print( - f"[ActivationMSELogger] Phase '{phase}': hooking {len(targets)} layers, " - f"max_samples={self.max_samples}" - ) - - # ----- storage ----- - saved: dict[str, list[torch.Tensor]] = {name: [] for name in targets} - captured: dict[str, torch.Tensor] = {} - hashes: list[str] = [] - - def _make_hook(key: str): - def hook(_module: nn.Module, _input, output) -> None: - captured[key] = _tensor_from_output(output).cpu() - - return hook - - hooks = [] - for name, module in targets.items(): - hooks.append(module.register_forward_hook(_make_hook(name))) - - try: - n_batches = 0 - for batch in tqdm(data, desc=f"Collecting ({phase})", leave=False): - if self.max_samples is not None and n_batches >= self.max_samples: - break - - captured.clear() - self._run_batch(model, batch) - - for name in targets: - if name in captured: - saved[name].append(captured[name]) - - hashes.append(self._hash_batch(batch)) - n_batches += 1 - finally: - for h in hooks: - h.remove() - - model.train(was_training) - - # ----- store results on self ----- - if phase == "original": - self.original_activations = saved - self.input_hashes = hashes - else: - self.quantized_activations = saved - self.quant_input_hashes = hashes - # Verify sample consistency - if self.input_hashes: - self._verify_hashes() - - # Invalidate any previous MSE since we have new activations - self.mse_results = None - - print(f"[ActivationMSELogger] Collected {n_batches} batches for phase '{phase}'") - - def compute_mse(self) -> dict[str, float]: - """Compute per-layer MSE between original and quantized activations. - - Returns: - Dict mapping layer name to its MSE value. - - Raises: - ValueError: If either phase has not been collected yet. - """ - if not self.original_activations: - raise ValueError( - "No original activations collected. Call collect(..., phase='original') first." - ) - if not self.quantized_activations: - raise ValueError( - "No quantized activations collected. 
Call collect(..., phase='quantized') first." - ) - - common_keys = sorted( - set(self.original_activations.keys()) & set(self.quantized_activations.keys()) - ) - if not common_keys: - raise ValueError( - "No matching layer names between original and quantized activations. " - "Ensure the same model architecture / layer_filter is used for both phases." - ) - - orig_only = set(self.original_activations.keys()) - set(self.quantized_activations.keys()) - quant_only = set(self.quantized_activations.keys()) - set(self.original_activations.keys()) - if orig_only: - print( - f"[ActivationMSELogger] Warning: {len(orig_only)} layers only in original (skipped)" - ) - if quant_only: - print( - f"[ActivationMSELogger] Warning: {len(quant_only)} layers only in quantized (skipped)" - ) - - sum_sq: dict[str, float] = dict.fromkeys(common_keys, 0.0) - count: dict[str, int] = dict.fromkeys(common_keys, 0) - - for name in common_keys: - orig_list = self.original_activations[name] - quant_list = self.quantized_activations[name] - n = min(len(orig_list), len(quant_list)) - for i in range(n): - o = orig_list[i].float() - q = quant_list[i].float() - if o.shape != q.shape: - print( - f"[ActivationMSELogger] Warning: shape mismatch for {name} " - f"batch {i}: {o.shape} vs {q.shape}, skipping" - ) - continue - sum_sq[name] += F.mse_loss(o, q, reduction="sum").item() - count[name] += o.numel() - - self.mse_results = { - key: (sum_sq[key] / count[key]) if count[key] > 0 else float("nan") - for key in common_keys - } - return self.mse_results - - def save(self, path: str | None = None) -> str: - """Save all state (activations, hashes, MSE) to disk via ``torch.save``. - - Args: - path: Explicit file path. If *None*, a timestamped file is created - inside ``self.save_dir`` (which must be set). - - Returns: - The path where the file was saved. 
- """ - if path is None: - if self.save_dir is None: - raise ValueError("Provide a path or set save_dir in the constructor.") - os.makedirs(self.save_dir, exist_ok=True) - ts = datetime.now().strftime("%Y%m%d_%H%M%S") - path = os.path.join(self.save_dir, f"activation_mse_{ts}.pt") - - payload = { - "max_samples": self.max_samples, - "layer_filter": self.layer_filter, - "input_hashes": self.input_hashes, - "quant_input_hashes": self.quant_input_hashes, - "original_activations": self.original_activations, - "quantized_activations": self.quantized_activations, - "mse": self.mse_results, - } - torch.save(payload, path) - print(f"[ActivationMSELogger] Saved to {path}") - return path - - @classmethod - def load(cls, path: str) -> "ActivationMSELogger": - """Load a previously saved ``ActivationMSELogger`` from disk. - - Args: - path: Path to the ``.pt`` file created by :meth:`save`. - - Returns: - A new ``ActivationMSELogger`` instance with restored state. - """ - payload = torch.load(path, map_location="cpu", weights_only=False) - logger = cls( - max_samples=payload.get("max_samples", 16), - layer_filter=payload.get("layer_filter"), - ) - logger.original_activations = payload.get("original_activations", {}) - logger.quantized_activations = payload.get("quantized_activations", {}) - logger.input_hashes = payload.get("input_hashes", []) - logger.quant_input_hashes = payload.get("quant_input_hashes", []) - logger.mse_results = payload.get("mse") - print(f"[ActivationMSELogger] Loaded from {path}") - return logger - - def summary(self) -> str: - """Return a formatted string summarising per-layer MSE results. - - Computes MSE first if not already done. 
- """ - if self.mse_results is None: - self.compute_mse() - assert self.mse_results is not None - - lines = ["Per-layer activation MSE (original vs quantized):"] - for key in sorted(self.mse_results.keys()): - lines.append(f" {key}: {self.mse_results[key]:.6e}") - return "\n".join(lines) - - # ------------------------------------------------------------------ - # Pre-materialized MSE data (cross-run / cross-codebase safety) - # ------------------------------------------------------------------ - - @staticmethod - def materialize_data( - data, - path: str, - max_samples: int | None = None, - ) -> list[torch.Tensor]: - """Freeze the first *max_samples* batches from *data* into a ``.pt`` file. - - Each batch (``dict``, ``Tensor``, or ``list/tuple``) is normalised to a - single ``input_ids`` CPU tensor before saving. The resulting file is a - plain ``List[Tensor]`` that can be loaded in **any** codebase and passed - straight to :meth:`collect`. - - If *path* already exists it is **not** overwritten -- call - :meth:`load_data` instead. - - Args: - data: Iterable of batches (DataLoader, List[Tensor], etc.). - path: Destination ``.pt`` file path. - max_samples: How many batches to keep. ``None`` means all. - - Returns: - The materialised list of CPU tensors (same object that was saved). 
- """ - samples: list[torch.Tensor] = [] - for batch in data: - if max_samples is not None and len(samples) >= max_samples: - break - if isinstance(batch, dict): - t = batch.get("input_ids", next(iter(batch.values()))) - elif isinstance(batch, torch.Tensor): - t = batch - elif isinstance(batch, (list, tuple)): - t = batch[0] - else: - raise TypeError(f"Unsupported batch type: {type(batch)}") - samples.append(t.cpu()) - - os.makedirs(os.path.dirname(path) or ".", exist_ok=True) - torch.save(samples, path) - print(f"[ActivationMSELogger] Materialised {len(samples)} MSE input samples -> {path}") - return samples - - @staticmethod - def load_data(path: str) -> list[torch.Tensor]: - """Load a previously materialised MSE input set. - - Args: - path: Path to the ``.pt`` file created by :meth:`materialize_data`. - - Returns: - ``List[Tensor]`` of input batches (on CPU). - """ - samples = torch.load(path, map_location="cpu", weights_only=True) - print(f"[ActivationMSELogger] Loaded {len(samples)} MSE input samples from {path}") - return samples - - # ------------------------------------------------------------------ - # Raw-text materialization (cross-model / cross-tokenizer reuse) - # ------------------------------------------------------------------ - - @staticmethod - def materialize_raw_text( - data, - path: str, - tokenizer=None, - max_samples: int | None = None, - ) -> list[str]: - """Save raw text strings to a JSON file for cross-model reuse. - - Extracts text from batches by decoding ``input_ids`` with the provided - *tokenizer*. The saved JSON file can be loaded by any model regardless - of its vocabulary and re-tokenized via :meth:`tokenize_raw_text`. - - Args: - data: Iterable of batches (DataLoader, ``List[Tensor]``, etc.). - path: Destination ``.json`` file path. - tokenizer: A HuggingFace tokenizer with a ``decode`` method. - Required to convert token IDs back to text. - max_samples: How many batches to keep. ``None`` means all. 
- - Returns: - The list of decoded text strings (same content that was saved). - """ - if tokenizer is None: - raise ValueError( - "tokenizer is required for materialize_raw_text to decode input_ids back to text." - ) - - texts: list[str] = [] - for batch in data: - if max_samples is not None and len(texts) >= max_samples: - break - if isinstance(batch, dict): - t = batch.get("input_ids", next(iter(batch.values()))) - elif isinstance(batch, torch.Tensor): - t = batch - elif isinstance(batch, (list, tuple)): - t = batch[0] - else: - raise TypeError(f"Unsupported batch type: {type(batch)}") - - if t.dim() == 1: - t = t.unsqueeze(0) - for row in t: - if max_samples is not None and len(texts) >= max_samples: - break - texts.append(tokenizer.decode(row, skip_special_tokens=True)) - - os.makedirs(os.path.dirname(path) or ".", exist_ok=True) - payload = {"texts": texts, "max_samples": len(texts)} - with open(path, "w", encoding="utf-8") as f: - json.dump(payload, f, ensure_ascii=False, indent=2) - - print(f"[ActivationMSELogger] Saved {len(texts)} raw text samples -> {path}") - return texts - - @staticmethod - def load_raw_text(path: str) -> list[str]: - """Load raw text strings from a JSON file created by :meth:`materialize_raw_text`. - - Args: - path: Path to the ``.json`` file. - - Returns: - List of raw text strings. - """ - with open(path, encoding="utf-8") as f: - payload = json.load(f) - texts = payload["texts"] - print(f"[ActivationMSELogger] Loaded {len(texts)} raw text samples from {path}") - return texts - - @staticmethod - def tokenize_raw_text( - texts: list[str], - tokenizer, - max_length: int = 2048, - ) -> list[torch.Tensor]: - """Tokenize raw text strings into a ``List[Tensor]`` for :meth:`collect`. - - Each string is independently tokenized and truncated to *max_length*. - Returns one ``[1, seq_len]`` tensor per string — the same format - expected by :meth:`collect` and :func:`compute_perplexity`. 
- - Args: - texts: List of raw text strings (from :meth:`load_raw_text`). - tokenizer: A HuggingFace tokenizer. - max_length: Maximum token length per sample (default: 2048). - - Returns: - ``List[Tensor]`` of tokenized inputs on CPU. - """ - samples: list[torch.Tensor] = [] - for text in texts: - encoded = tokenizer( - text, - return_tensors="pt", - max_length=max_length, - truncation=True, - add_special_tokens=False, - ) - samples.append(encoded.input_ids.cpu()) - print(f"[ActivationMSELogger] Tokenized {len(samples)} samples (max_length={max_length})") - return samples - - # ------------------------------------------------------------------ - # Static / private helpers - # ------------------------------------------------------------------ - - @staticmethod - def _run_batch(model: nn.Module, batch) -> None: - """Run a single batch through the model (handles Tensor, dict, list/tuple). - - Automatically moves inputs to the model's device so that CPU-stored - materialized data works transparently with a CUDA model. - """ - device = next(model.parameters()).device - if isinstance(batch, dict): - batch = { - k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items() - } - model(**batch) - elif isinstance(batch, torch.Tensor): - model(batch.to(device)) - elif isinstance(batch, (list, tuple)): - batch = tuple(t.to(device) if isinstance(t, torch.Tensor) else t for t in batch) - model(*batch) - else: - raise TypeError(f"Unsupported batch type: {type(batch)}") - - @staticmethod - def _hash_batch(batch) -> str: - """Compute SHA-256 hash of the primary input tensor in *batch*. - - - ``dict`` -> hashes ``batch["input_ids"]`` (falls back to first value). - - ``Tensor`` -> hashes the tensor directly. - - ``list/tuple`` -> hashes the first element. 
- """ - if isinstance(batch, dict): - t = batch.get("input_ids", next(iter(batch.values()))) - elif isinstance(batch, torch.Tensor): - t = batch - elif isinstance(batch, (list, tuple)): - t = batch[0] if batch else None - else: - return "" - - if t is None or not isinstance(t, torch.Tensor): - return "" - return hashlib.sha256(t.cpu().contiguous().numpy().tobytes()).hexdigest() - - def _verify_hashes(self) -> None: - """Compare input hashes between original and quantized phases.""" - n = min(len(self.input_hashes), len(self.quant_input_hashes)) - mismatches = sum(1 for i in range(n) if self.input_hashes[i] != self.quant_input_hashes[i]) - if mismatches: - print( - f"[ActivationMSELogger] WARNING: {mismatches}/{n} batches have " - f"different input hashes between original and quantized phases. " - f"The same data may not have been used for both phases!" - ) - else: - print(f"[ActivationMSELogger] Input hash verification passed ({n}/{n} match)") diff --git a/modelopt/torch/quantization/metrics/perplexity.py b/modelopt/torch/quantization/metrics/perplexity.py deleted file mode 100644 index 2b592914ae..0000000000 --- a/modelopt/torch/quantization/metrics/perplexity.py +++ /dev/null @@ -1,81 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -# mypy: ignore-errors -# ruff: noqa: D103, PERF401 - -"""Perplexity evaluation for language models. - -Ported from FP-Quant: https://github.com/IST-DASLab/FP-Quant -""" - -import torch -import torch.nn.functional as F -from tqdm import trange - - -@torch.no_grad() -def compute_perplexity(model, data, batch_size: int = 1): - num_samples = len(data) - device = next(model.parameters()).device - # Running estimate of negative log-likelihood - nll_running = 0 - # Number of tokens processed to far - tokens_processed = 0 - # Loop through each batch - for i in trange(0, num_samples, batch_size, desc="Computing perplexity", leave=False): - j = min(i + batch_size, num_samples) - inputs = torch.cat(data[i:j]).to(device) - # Forward pass through the model - lm_logits = model(inputs).logits - # Shift logits and labels for next token prediction - shift_logits = lm_logits[:, :-1, :].contiguous() - shift_labels = inputs[:, 1:] - # Compute loss - loss = F.cross_entropy( - shift_logits.reshape(-1, shift_logits.size(-1)), shift_labels.reshape(-1) - ) - # Calculate negative log likelihood - a = shift_labels.numel() / (tokens_processed + shift_labels.numel()) - b = tokens_processed / (tokens_processed + shift_labels.numel()) - nll_running = a * loss + b * nll_running - # Update number of processed tokens - tokens_processed += shift_labels.numel() - # Compute perplexity - ppl = nll_running.exp().item() - return ppl - - -def get_wikitext2(tokenizer, sequence_length: int): - """Load WikiText-2 test set as a list of tokenized sequences for perplexity evaluation. - - Args: - tokenizer: HuggingFace tokenizer. - sequence_length: Length of each evaluation sequence. - - Returns: - List of tensors, each of shape ``[1, sequence_length]``. 
- """ - from datasets import load_dataset - - test_dataset_raw = load_dataset("wikitext", "wikitext-2-raw-v1", split="test") - test_dataset_tok = tokenizer( - "\n\n".join(test_dataset_raw["text"]), return_tensors="pt" - ).input_ids - num_test_sequences = test_dataset_tok.numel() // sequence_length - test_loader = [] - for i in range(num_test_sequences): - test_loader.append(test_dataset_tok[:, i * sequence_length : (i + 1) * sequence_length]) - return test_loader From a94849736408fd1bdf67f6f9ca7974be45da2ce1 Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Sun, 22 Mar 2026 02:32:16 +0000 Subject: [PATCH 35/48] refactor Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- modelopt/torch/quantization/config.py | 38 - modelopt/torch/quantization/mode.py | 14 - modelopt/torch/quantization/model_calib.py | 670 +++++------------- .../quantization/triton/gptq_fused_kernel.py | 189 ----- tests/gpu/torch/quantization/test_gptq.py | 104 +-- 5 files changed, 175 insertions(+), 840 deletions(-) delete mode 100644 modelopt/torch/quantization/triton/gptq_fused_kernel.py diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index c0bfe62508..d096f8390e 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -1344,44 +1344,6 @@ class SVDQuantConfig(QuantizeAlgorithmConfig): ) -class GPTQLiteConfig(QuantizeAlgorithmConfig): - """The config for GPTQ lite. - - GPTQ lite is a variant of GPTQ that does not exactly follow the official GPTQ implementation. - - GPTQ lite does not perform sequential quantization of layers. This means that the updated - activations are not used to process the next layer. 
- - The default values are taken from the official GPTQ implementation: - https://github.com/IST-DASLab/FP-Quant/blob/d2e3092f968262c4de5fb050e1aef568a280dadd/src/quantization/gptq.py#L35 - - Note: This feature is currently experimental and may not translate to improved accuracy as expected. - - - """ - - method: Literal["gptq_lite"] = ModeloptField("gptq_lite") - percdamp: float | None = ModeloptField( - default=0.01, - gt=0.0, - le=1.0, - title="Percentage damping factor.", - description="The percentage of average Hessian diagonal used for damping.", - ) - block_size: int | None = ModeloptField( - default=128, - title="Block size for GPTQ weight update.", - description="""The block size for GPTQ weight update, which must be a multiple of the - group_size used in the quantization.""", - ) - hessian_state_path: str | None = ModeloptField( - default=None, - title="Path to the Hessian state file.", - description="""The path to the Hessian state file. If hessian path exists, we load from - hessian file instead of recomputing them.""", - ) - - class GPTQConfig(QuantizeAlgorithmConfig): """The config for GPTQ quantization. 
diff --git a/modelopt/torch/quantization/mode.py b/modelopt/torch/quantization/mode.py index df48c72c29..63b3a7c913 100644 --- a/modelopt/torch/quantization/mode.py +++ b/modelopt/torch/quantization/mode.py @@ -38,7 +38,6 @@ AWQLiteCalibConfig, CompressConfig, GPTQConfig, - GPTQLiteConfig, LocalHessianCalibConfig, MaxCalibConfig, MseCalibConfig, @@ -61,7 +60,6 @@ from .model_calib import ( awq, gptq, - gptq_lite, local_hessian_calibrate, max_calibrate, mse_calibrate, @@ -494,18 +492,6 @@ def restore(self) -> RestoreEntrypoint: return restore_svdquant_model -@CalibrateModeRegistry.register_mode -class GPTQLiteModeDescriptor(BaseCalibrateModeDescriptor): - """Mode for GPTQ calibration algorithm.""" - - @property - def config_class(self) -> type[QuantizeAlgorithmConfig]: - """Specifies the config class for the mode.""" - return GPTQLiteConfig - - _calib_func = gptq_lite - - @CalibrateModeRegistry.register_mode class GPTQModeDescriptor(BaseCalibrateModeDescriptor): """Mode for GPTQ calibration algorithm.""" diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 4c7ecf86ec..c5b4bc2aab 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -16,7 +16,6 @@ """Calibration utilities.""" import math -import os import time import warnings from collections.abc import Callable @@ -1520,461 +1519,187 @@ def postprocess(module, name): max_calibrate(model, forward_loop) -def _print_relative_mse_error( - q: torch.Tensor, - w: torch.Tensor, - h: torch.Tensor, - module_name: str, - n_samples: int | None = None, -): - """Print relative mean squared error between quantized and original weights. - - Computes the Hessian-weighted relative MSE between quantized and original weights, - providing a measure of quantization quality. This metric is adapted from the GPTQ - repository. 
- - Args: - q (torch.Tensor): Quantized weight tensor - w (torch.Tensor): Original weight tensor - h (torch.Tensor): Hessian matrix used for weighting the error - module_name (str): Name of the module for logging purposes - n_samples (int | None): Number of Hessian samples (batches) used for this layer - Note: - Implementation adapted from the GPTQ repository: - https://github.com/IST-DASLab/FP-Quant - """ - delta = q - w - mse = (delta).mm(h).mul(delta).mean() / (w.mm(h).mul(w).mean() + 1e-6) - suffix = f", n_hessian_samples: {n_samples}" if n_samples is not None else "" - print_rank_0(f"[{module_name}] Relative MSE error: {mse.item():.2e}{suffix}") - - -def update_hessian(input, hessian, n_samples): - """Update hessian matrix with new input samples using incremental formula. - - Args: - input: Input tensor (batch_size, ..., features) - hessian: Current Hessian matrix to update in-place - n_samples: Number of samples already processed - Returns: - Tuple of (updated_hessian, new_sample_count) - """ - # Flatten to 2D (total_tokens, features) first, so batch_size counts tokens - input_flat = input.reshape(-1, input.shape[-1]).t().float() - batch_size = input_flat.shape[1] - - # Incremental averaging: scale down old hessian - hessian *= n_samples / (n_samples + batch_size) - n_samples += batch_size - - # Compute outer product: H += (2/n_samples) * X @ X^T - scaled_input = math.sqrt(2 / n_samples) * input_flat - hessian.add_((scaled_input @ scaled_input.t()).to(hessian.device)) - - return hessian, n_samples - - -def prepare_hessian_inverse(h, weight, percdamp): - """Prepare inverse Hessian with dead neuron handling and damping. 
- - Args: - h: Hessian matrix to update - weight: Weight tensor to prepare Hessian for - percdamp: Damping percentage for Hessian diagonal - Returns: - h_inv: Inverse Hessian matrix - Implementation adapted from the FP-Quant repository: - https://github.com/IST-DASLab/FP-Quant/blob/d2e3092f968262c4de5fb050e1aef568a280dadd/src/quantization/gptq.py#L200 - """ - h = h.clone() - # Handle dead neurons (zero weight columns) - # Get columns with all zeros in weight - zero_cols = torch.nonzero(weight.eq(0).all(dim=0)).unsqueeze(-1) - - # Zero out entire rows and columns in Hessian for dead neurons - h[zero_cols, :] = 0 - h[:, zero_cols] = 0 - h[zero_cols, zero_cols] = 1 - - # Add damping to diagonal - damp = percdamp * torch.mean(torch.diag(h)) - diag_indices = torch.arange(h.shape[0], device=h.device) - h[diag_indices, diag_indices] += damp - - try: - h = torch.cholesky_inverse(torch.linalg.cholesky(h)) - h_inv = torch.linalg.cholesky(h, upper=True) - except (RuntimeError, torch.linalg.LinAlgError): - print_rank_0("Warning: Hessian is not positive definite, using identity matrix") - h_inv = torch.eye(h.shape[0], device=h.device, dtype=h.dtype) - return h_inv - - -def _build_column_qdq(quantizer, weight_shape): - """Build a fast column-wise quantize-dequantize function for integer quantizers. +class GPTQHandle: + """Encapsulates per-module GPTQ state and operations. - Instead of calling the full TensorQuantizer on the entire weight matrix (which - quantizes all elements) and extracting one column, this returns a closure that - quantizes only a single column using the quantizer's pre-computed amax/scales. + Owns the Hessian, patches the forward during collection, and contains + the blockwise weight-update logic. - Since max_calibrate fixes the amax before GPTQ weight updates, quantizing a - single column with the same fixed scale gives bit-identical results to - quantizing the full matrix and extracting that column. 
+ Instance attributes set during ``__init__``: + module, name, hessian, n_samples - Args: - quantizer: The weight TensorQuantizer (already calibrated). - weight_shape: Shape of the weight tensor (out_features, in_features). - - Returns: - Tuple of (column_qdq_fn, supported) where: - - column_qdq_fn(column, col_idx) -> qdq_column (if supported) - - supported: True if column-wise qdq is available, False to fall back. + Instance attributes set during ``quantize``: + weight: float working copy of module weights (mutated in-place by update methods) + h_inv: upper-triangular Cholesky factor of the damped inverse Hessian """ - # Unsupported: NVFP4 (two-level FP4 scaling), FP quantization (num_bits is a tuple) - if isinstance(quantizer, NVFP4StaticQuantizer): - return None, False - if isinstance(quantizer._num_bits, tuple): - return None, False - - # Unsupported: pre_quant_scale (SmoothQuant) or rotation transforms mix columns - if getattr(quantizer, "pre_quant_scale", None) is not None: - return None, False - if getattr(quantizer, "rotate_is_enabled", False): - return None, False - - # Need calibrated amax - if not hasattr(quantizer, "_amax") or quantizer._amax is None: - return None, False - - num_bits = quantizer._num_bits - unsigned = getattr(quantizer, "_unsigned", False) - narrow_range = getattr(quantizer, "_narrow_range", False) - max_bound = (2 ** (num_bits - 1 + int(unsigned))) - 1 - min_bound = -max_bound + int(narrow_range) - - amax = quantizer._amax.float() - out_features, in_features = weight_shape - - # Determine quantization geometry from block_sizes - block_sizes = quantizer.block_sizes - group_size = None - if block_sizes is not None: - # Skip dynamic block quantization - if block_sizes.get("type", "static") == "dynamic": - return None, False - group_size = block_sizes.get(-1, None) or block_sizes.get(len(weight_shape) - 1, None) - - if group_size is not None and group_size > 0: - # Per-group block quantization along last dim. 
- # After _setup_for_blockquant, weight is reshaped to (-1, group_size) with axis=(0,). - # amax shape: (out_features * n_groups, 1) where n_groups = in_features // group_size. - if in_features % group_size != 0: - return None, False # Padding case — fall back - - n_groups = in_features // group_size - - try: - # Reshape amax to (out_features, n_groups) for O(1) group lookup - amax_2d = amax.reshape(out_features, n_groups) - except RuntimeError: - return None, False - - def _column_qdq_group( - col, col_idx, _a=amax_2d, _mx=max_bound, _mn=min_bound, _gs=group_size - ): - col_scale = _mx / _a[:, col_idx // _gs].clamp(min=1e-12) - return torch.clamp(torch.round(col * col_scale), _mn, _mx) / col_scale - - return _column_qdq_group, True - - # Per-channel (axis != None) or per-tensor (axis == None) - axis = quantizer.axis - if axis is not None: - # Per-channel: amax has shape (out_features, 1) or similar - col_scale = max_bound / amax.reshape(-1).clamp(min=1e-12) - - def _column_qdq_channel(col, col_idx, _s=col_scale, _mx=max_bound, _mn=min_bound): - return torch.clamp(torch.round(col * _s), _mn, _mx) / _s - - return _column_qdq_channel, True - # Per-tensor: single scalar scale - scalar_scale = max_bound / amax.clamp(min=1e-12).item() + CACHE_NAME = "_forward_no_gptq_hessian" + + def __init__(self, module, name, offload_to_cpu=False): + self.module = module + self.name = name + in_features = module.weight.shape[-1] + device = module.weight.device + if offload_to_cpu and get_used_gpu_mem_fraction(device) > 0.65: + device = "cpu" + self.hessian = torch.zeros(in_features, in_features, dtype=torch.float32, device=device) + self.n_samples = 0 + # Set by quantize(); listed here for documentation. 
+ self.weight: torch.Tensor | None = None + self.h_inv: torch.Tensor | None = None + + def setup(self): + """Patch the module's forward to accumulate Hessian during the collection pass.""" + gptq_handle = self - def _column_qdq_tensor(col, col_idx, _s=scalar_scale, _mx=max_bound, _mn=min_bound): - return torch.clamp(torch.round(col * _s), _mn, _mx) / _s - - return _column_qdq_tensor, True - - -def _can_use_fused_gptq(quantizer) -> bool: - """Check whether the fused Triton GPTQ kernel can be used for *quantizer*.""" - if not isinstance(quantizer, NVFP4StaticQuantizer): - return False - if not hasattr(quantizer, "_amax") or quantizer._amax is None: - return False - from modelopt.torch.quantization.triton import IS_AVAILABLE as _TRITON_OK - - return _TRITON_OK + def hessian_forward(self, input, *args, **kwargs): + inp = input.to_local() if hasattr(input, "to_local") else input + if self.input_quantizer is not None and self.input_quantizer.is_enabled: + hessian_input = self.input_quantizer(inp) + else: + hessian_input = inp + gptq_handle.hessian, gptq_handle.n_samples = update_hessian( + hessian_input, gptq_handle.hessian, gptq_handle.n_samples + ) + self.weight_quantizer.disable() + out = self._forward_no_gptq_hessian(input, *args, **kwargs) + self.weight_quantizer.enable() + return out -def blockwise_weight_update(module, h, block_size, percdamp, n_samples=None): - """Update module weights using GPTQ-style blockwise quantization. + bind_forward_method(self.module, hessian_forward, self.CACHE_NAME) - Dispatches to one of three internal paths depending on quantizer type: + def cleanup(self): + """Unpatch the module's forward method.""" + unpatch_forward_method(self.module, self.CACHE_NAME) - 1. **Fused Triton** — for :class:`NVFP4StaticQuantizer` when Triton is - available. Runs the entire column loop in a single GPU kernel per - block (~130x faster than the unfused path on Blackwell GPUs). - 2. 
**Column-QDQ** — for integer quantizers whose scale geometry allows - single-column fake-quant via :func:`_build_column_qdq`. - 3. **Full-matrix fallback** — calls the quantizer on the full weight matrix - each column (slowest, but always correct). + def quantize(self, block_size, percdamp): + """Run GPTQ blockwise weight update on this module. - Args: - module: Neural network module with ``weight`` and ``weight_quantizer``. - h: Hessian matrix of shape ``(d, d)``. - block_size: Number of columns processed per block. - percdamp: Damping as a fraction of the mean Hessian diagonal. - n_samples: Number of Hessian samples (used only for logging). - """ - weight = module.weight.data.float().clone() - num_rows, num_cols = weight.shape + Populates ``self.weight`` and ``self.h_inv``, runs the blockwise update, + logs MSE, and writes the result back to the module. + """ + hessian = self.hessian.to(self.module.weight.device) + self.weight = self.module.weight.data.float().clone() + self._prepare_hessian_inverse(hessian, percdamp) - h_inv = prepare_hessian_inverse(h, weight, percdamp) + self._blockwise_update(block_size) - quantizer = module.weight_quantizer - if _can_use_fused_gptq(quantizer): - _blockwise_weight_update_fused(weight, h_inv, quantizer, num_rows, num_cols, block_size) - else: - col_qdq_fn, col_qdq_supported = _build_column_qdq(quantizer, weight.shape) - _blockwise_weight_update_unfused( - weight, h_inv, quantizer, num_cols, block_size, col_qdq_fn, col_qdq_supported + self._print_mse_error(hessian) + self.module.weight.data = self.weight.reshape(self.module.weight.shape).to( + self.module.weight.data.dtype ) - _print_relative_mse_error(weight, module.weight.float(), h, module.name, n_samples) - module.weight.data = weight.reshape(module.weight.shape).to(module.weight.data.dtype) - + # ------------------------------------------------------------------ + # Quantize helpers — all read from self.module, self.weight, self.h_inv + # 
------------------------------------------------------------------ -def _blockwise_weight_update_fused(weight, h_inv, quantizer, num_rows, num_cols, block_size): - """Fused Triton path for NVFP4: one kernel launch per block.""" - from modelopt.torch.quantization.triton.gptq_fused_kernel import gptq_fused_block + def _prepare_hessian_inverse(self, hessian, percdamp): + """Compute damped inverse Hessian and store as ``self.h_inv``. - group_size = quantizer.block_sizes.get(-1, None) or quantizer.block_sizes.get(1, None) - num_groups = math.ceil(num_cols / group_size) - amax_grouped = quantizer._amax.float().reshape(num_rows, num_groups).contiguous() - global_amax = quantizer.global_amax.float() + Dead-neuron columns (all-zero in ``self.weight``) are zeroed in the + Hessian before inversion, matching the FP-Quant reference: + https://github.com/IST-DASLab/FP-Quant/blob/d2e3092f968262c4de5fb050e1aef568a280dadd/src/quantization/gptq.py#L200 + """ + assert self.weight is not None, "_prepare_hessian_inverse called before quantize()" + h = hessian.clone() + zero_cols = torch.nonzero(self.weight.eq(0).all(dim=0)).unsqueeze(-1) - for block_start in range(0, num_cols, block_size): - block_end = min(block_start + block_size, num_cols) - n_cols_blk = block_end - block_start + h[zero_cols, :] = 0 + h[:, zero_cols] = 0 + h[zero_cols, zero_cols] = 1 - w_block = weight[:, block_start:block_end].clone().contiguous() - h_inv_cho_blk = h_inv[block_start:block_end, block_start:block_end].contiguous() + damp = percdamp * torch.mean(torch.diag(h)) + diag_indices = torch.arange(h.shape[0], device=h.device) + h[diag_indices, diag_indices] += damp - qw_block, err_block = gptq_fused_block( - w_block, - amax_grouped, - global_amax, - h_inv_cho_blk, - group_size, - block_start, - n_cols_blk, + try: + h = torch.cholesky_inverse(torch.linalg.cholesky(h)) + self.h_inv = torch.linalg.cholesky(h, upper=True) + except (RuntimeError, torch.linalg.LinAlgError): + print_rank_0("Warning: Hessian is not 
positive definite, using identity matrix") + self.h_inv = torch.eye(h.shape[0], device=h.device, dtype=h.dtype) + + def _blockwise_update(self, block_size): + """Column-wise GPTQ update using full-matrix QDQ. + + For each column, quantizes the full weight matrix via the quantizer and + extracts the quantized column. This is the standard GPTQ approach. + + Reads/writes ``self.weight`` and ``self.h_inv`` in-place. + """ + assert self.weight is not None and self.h_inv is not None, ( + "_blockwise_update called before _prepare_hessian_inverse()" ) + quantizer = self.module.weight_quantizer + num_cols = self.weight.shape[1] - weight[:, block_start:block_end] = qw_block - if block_end < num_cols: - weight[:, block_end:].addmm_( - err_block[:, :n_cols_blk], - h_inv[block_start:block_end, block_end:], - alpha=-1, - ) + for block_start in range(0, num_cols, block_size): + block_end = min(block_start + block_size, num_cols) + n_cols_blk = block_end - block_start + h_inv_cho_blk = self.h_inv[block_start:block_end, block_start:block_end] - -def _blockwise_weight_update_unfused( - weight, h_inv, quantizer, num_cols, block_size, col_qdq_fn, col_qdq_supported -): - """Column-QDQ or full-matrix fallback for non-NVFP4 quantizers.""" - for block_start in range(0, num_cols, block_size): - block_end = min(block_start + block_size, num_cols) - n_cols_blk = block_end - block_start - h_inv_cho_blk = h_inv[block_start:block_end, block_start:block_end] - - # wblk is a scratch copy for intra-block error propagation; weight gets - # the final quantized values. Inter-block errors are propagated via addmm_ below. 
- if col_qdq_supported: - wblk = weight[:, block_start:block_end].clone() - errs = torch.zeros_like(wblk) - - for i in range(n_cols_blk): - w_ci = wblk[:, i] - d = h_inv_cho_blk[i, i] - qdq_col = col_qdq_fn(w_ci, block_start + i) - weight[:, block_start + i] = qdq_col - err = (w_ci - qdq_col) / d - wblk[:, i:].addr_(err, h_inv_cho_blk[i, i:], alpha=-1) - errs[:, i] = err - else: - wblk = weight.clone() + wblk = self.weight.clone() errs = torch.zeros_like(wblk[:, block_start:block_end]) for i in range(n_cols_blk): w_ci = wblk[:, block_start + i] d = h_inv_cho_blk[i, i] qdq = quantizer(wblk) - weight[:, block_start + i] = qdq[:, block_start + i] + self.weight[:, block_start + i] = qdq[:, block_start + i] err = (w_ci - qdq[:, block_start + i]) / d wblk[:, block_start + i : block_end].addr_(err, h_inv_cho_blk[i, i:], alpha=-1) errs[:, i] = err - weight[:, block_end:].addmm_(errs, h_inv[block_start:block_end, block_end:], alpha=-1) + self.weight[:, block_end:].addmm_( + errs, self.h_inv[block_start:block_end, block_end:], alpha=-1 + ) + def _print_mse_error(self, hessian): + """Log Hessian-weighted relative MSE between ``self.weight`` and original weights.""" + w_orig = self.module.weight.float() + delta = self.weight - w_orig + mse = (delta).mm(hessian).mul(delta).mean() / (w_orig.mm(hessian).mul(w_orig).mean() + 1e-6) + suffix = f", n_hessian_samples: {self.n_samples}" if self.n_samples else "" + print_rank_0(f"[{self.name}] Relative MSE error: {mse.item():.2e}{suffix}") -def gptq_lite( - model: nn.Module, - forward_loop: ForwardLoop | None = None, - percdamp: float = 0.01, - block_size: int = 128, - hessian_state_path: str | None = None, -): - """GPTQ-lite quantization - a simplified GPTQ variant. 
- Key differences from GPTQ: - - Layers are quantized in parallel (not sequentially with updated activations) - - Uses group-wise updates instead of column-wise updates +def update_hessian(input, hessian, n_samples): + """Update hessian matrix with new input samples using incremental formula. Args: - model: Model to be calibrated. - forward_loop: Callable that forwards calibration data through the model. - percdamp: Percentage of avg Hessian diagonal for damping (default: 0.01). - block_size: Block size for GPTQ weight update. - hessian_state_path: Path to save/load Hessian state. If None, compute without saving. - If path exists, load from it. If path doesnt exist then save computed hessians to path. - - See :class:`GPTQLiteConfig ` for - details on the remaining arguments. - - Note: This feature is currently experimental and may not translate to improved accuracy as expected. + input: Input tensor (batch_size, ..., features) + hessian: Current Hessian matrix to update in-place + n_samples: Number of samples already processed + Returns: + Tuple of (updated_hessian, new_sample_count) """ - # Dictionary to store hessian matrices: {layer_name: {"hessian": Tensor, "n_samples": int}} - hessian_state = {} - - def initialize_hessian_state(tensor_mapping): - """Initialize hessian state with zeros.""" - for name, (shape, device) in tensor_mapping.items(): - # Use CPU if GPU memory is tight - target_device = "cpu" if get_used_gpu_mem_fraction(device) > 0.65 else device - hessian_state[name] = { - "hessian": torch.zeros(shape, dtype=torch.float32, device=target_device), - "n_samples": 0, - } - - def load_hessian_state(path, tensor_mapping): - """Load hessian state from file.""" - print_rank_0(f"Loading hessian state from {path}") - loaded_state = torch.load(path, map_location="cpu") - - for name, (shape, device) in tensor_mapping.items(): - if name not in loaded_state: - raise KeyError(f"Layer '{name}' not found in loaded hessian state") - - # Move to appropriate device 
based on memory - target_device = "cpu" if get_used_gpu_mem_fraction(device) > 0.65 else device - hessian_state[name] = { - "hessian": loaded_state[name]["hessian"].to(target_device), - "n_samples": loaded_state[name]["n_samples"], - } + # Flatten to 2D (total_tokens, features) first, so batch_size counts tokens + input_flat = input.reshape(-1, input.shape[-1]).t().float() + batch_size = input_flat.shape[1] - print_rank_0(f"Successfully loaded hessian state with {len(hessian_state)} layers") + # Incremental averaging: scale down old hessian + hessian *= n_samples / (n_samples + batch_size) + n_samples += batch_size - def save_hessian_state(path): - """Save hessian state to file.""" - print_rank_0(f"Saving hessian state to {path}") - try: - # Move to CPU for saving - cpu_state = { - name: {"hessian": state["hessian"].cpu(), "n_samples": state["n_samples"]} - for name, state in hessian_state.items() - } + # Compute outer product: H += (2/n_samples) * X @ X^T + scaled_input = math.sqrt(2 / n_samples) * input_flat + hessian.add_((scaled_input @ scaled_input.t()).to(hessian.device)) - os.makedirs(os.path.dirname(path) if os.path.dirname(path) else ".", exist_ok=True) - torch.save(cpu_state, path) - print_rank_0(f"Successfully saved hessian state to {path}") - except Exception as e: - print_rank_0(f"Error saving hessian state: {e}") - print_rank_0("Continuing execution...") + return hessian, n_samples - def hessian_hook(module, input, output): - """Hook to intercept activations and update hessian matrix.""" - state = hessian_state[module.name] - hessian, n_samples = update_hessian(input[0], state["hessian"], state["n_samples"]) - hessian_state[module.name] = {"hessian": hessian, "n_samples": n_samples} - # Phase 1: Collect statistics for quantizers - max_calibrate(model) +def _get_quantized_linear_layers(parent: nn.Module) -> list[tuple[str, nn.Module]]: + """Return (name, module) pairs for all quantized linear layers with enabled weight quantizers. 
- # Phase 2: Build tensor mapping for all quantized layers - tensor_mapping = {} - for name, module in model.named_modules(): + Also sets ``module.name`` on each returned module for downstream logging. + """ + layers = [] + for name, module in parent.named_modules(): if is_quantized_linear(module) and module.weight_quantizer.is_enabled: - in_features = module.weight.shape[-1] - tensor_mapping[name] = ((in_features, in_features), module.weight.device) - module.name = name # Attach name for easy access in hooks - - # Phase 3: Load or compute Hessians - hessian_exists = hessian_state_path is not None and os.path.exists(hessian_state_path) - save_hessians = hessian_state_path is not None and not hessian_exists - - if hessian_exists: - print_rank_0(f"Loading hessian state from {hessian_state_path}") - load_hessian_state(hessian_state_path, tensor_mapping) - else: - if forward_loop is None: - raise ValueError("forward_loop must be provided when computing Hessians") - - # Initialize hessian state - initialize_hessian_state(tensor_mapping) - - # Register hooks to collect activations - handles = [] - for name, module in model.named_modules(): - if is_quantized_linear(module) and module.weight_quantizer.is_enabled: - handles.append(module.register_forward_hook(hessian_hook)) - - # Run forward loop to compute hessians - print_rank_0("Computing Hessian matrices...") - forward_loop(model) - - for handle in handles: - handle.remove() - - # Save if configured - if save_hessians: - try: - save_hessian_state(hessian_state_path) - except Exception as e: - print_rank_0(f"Error saving hessian state: {e}") - print_rank_0("Continuing execution...") - - # Phase 4: Update weights using computed Hessians - print_rank_0("Updating weights using GPTQ-lite algorithm...") - - quantized_modules = [ - (name, module) - for name, module in model.named_modules() - if is_quantized_linear(module) and module.weight_quantizer.is_enabled - ] - - # Perform blockwise weight updates - for name, module in 
tqdm(quantized_modules, desc="Quantizing layers"): - state = hessian_state[module.name] - hessian = state["hessian"].to(module.weight.device) - blockwise_weight_update(module, hessian, block_size, percdamp) - # Delete hessian state to free memory - del hessian_state[module.name] - torch.cuda.empty_cache() - - print_rank_0("GPTQ-lite quantization completed successfully") + module.name = name + layers.append((name, module)) + return layers @torch.no_grad() @@ -2056,17 +1781,22 @@ def _promote_nvfp4_static_quantizers(model: nn.Module) -> int: @torch.no_grad() def gptq( - layer: nn.Module, + model: nn.Module, forward_loop: ForwardLoop, percdamp: float = 0.01, block_size: int = 128, - **kwargs, ): - """GPTQ quantization for a single decoder layer. + """GPTQ quantization. + + Works in two modes depending on ``use_sequential`` in the config: + + * **Sequential** (``use_sequential=True``): ``sequential_calibrate`` calls this + function once per decoder layer with updated activations, producing more + accurate Hessian estimates. + * **Non-sequential** (``use_sequential=False``): called once on the full model. + All layers are quantized in parallel from the original activations. - Invoked by ``sequential_calibrate`` which walks layers one at a time so each - layer sees activations already updated by the quantization of preceding layers. - Within a layer the steps are: + Per-module steps: 1. ``max_calibrate`` to set amax values from the current activations. 2. Promote eligible quantizers to ``NVFP4StaticQuantizer`` (two-level scaling). @@ -2074,106 +1804,38 @@ def gptq( 4. Blockwise weight updates using the inverse Hessian to compensate for rounding error (the core GPTQ column-wise update). - In contrast to ``gptq_lite``, which quantizes all layers in parallel using the - original (unquantized) activations, this method performs sequential calibration - and therefore produces more accurate Hessian estimates. - Args: - layer: A single decoder layer to quantize. 
- forward_loop: Callable that replays calibration inputs through the layer. - Provided by ``sequential_calibrate`` which captures per-layer activations. + model: The module to quantize — either the full model or a single decoder + layer when invoked by ``sequential_calibrate``. + forward_loop: Callable that replays calibration inputs through *model*. percdamp: Percentage of avg Hessian diagonal for damping (default: 0.01). block_size: Block size for GPTQ weight update. """ total_start = time.time() - # Set weight amax and activation amax for the current layer using max_calibrate - max_calibrate(layer, forward_loop=forward_loop) - - # Promote NVFP4 static quantizers so they use the two-level scaling path - n_promoted = _promote_nvfp4_static_quantizers(layer) - if n_promoted: - print_rank_0(f"Promoted {n_promoted} quantizer(s) to NVFP4StaticQuantizer") + max_calibrate(model, forward_loop=forward_loop) + _promote_nvfp4_static_quantizers(model) - # Dictionary to store hessian matrices for all linear layers in this decoder - hessian_state = {} - - # Phase 1: Build tensor mapping for all quantized linear layers in this decoder layer - tensor_mapping = {} - for name, module in layer.named_modules(): - if is_quantized_linear(module) and module.weight_quantizer.is_enabled: - in_features = module.weight.shape[-1] - tensor_mapping[name] = ((in_features, in_features), module.weight.device) - module.name = name # Attach name for easy access in hooks - - if not tensor_mapping: - print_rank_0("No quantized linear layers found in decoder layer, skipping GPTQ") + quantized_layers = _get_quantized_linear_layers(model) + if not quantized_layers: + print_rank_0("No quantized linear layers found, skipping GPTQ") return - # Initialize hessian state with zeros - for name, (shape, device) in tensor_mapping.items(): - hessian_state[name] = { - "hessian": torch.zeros(shape, dtype=torch.float32, device=device), - "n_samples": 0, - } - - # Phase 2: Patch forwards to collect Hessians (similar 
to local_hessian_calibrate) - def _make_hessian_forward(module_name): - def hessian_forward(self, input, *args, **kwargs): - inp = input.to_local() if hasattr(input, "to_local") else input - if self.input_quantizer is not None and self.input_quantizer.is_enabled: - hessian_input = self.input_quantizer(inp) - else: - hessian_input = inp - state = hessian_state[module_name] - hessian, n_samples = update_hessian(hessian_input, state["hessian"], state["n_samples"]) - hessian_state[module_name] = {"hessian": hessian, "n_samples": n_samples} - - self.weight_quantizer.disable() - out = self._forward_no_gptq_hessian(input, *args, **kwargs) - self.weight_quantizer.enable() - return out - - return hessian_forward + gptq_handles = {name: GPTQHandle(m, name, offload_to_cpu=True) for name, m in quantized_layers} + for handle in gptq_handles.values(): + handle.setup() - patched_modules = [] - for name, module in layer.named_modules(): - if is_quantized_linear(module) and module.weight_quantizer.is_enabled: - bind_forward_method(module, _make_hessian_forward(name), "_forward_no_gptq_hessian") - patched_modules.append(module) - - # Run forward passes to collect Hessians - hessian_start = time.time() - print_rank_0(f"Computing Hessians for {len(tensor_mapping)} linear layers...") - forward_loop(layer) - - # Unpatch forwards - for module in patched_modules: - unpatch_forward_method(module, "_forward_no_gptq_hessian") + print_rank_0(f"Computing Hessians for {len(gptq_handles)} linear layers...") + forward_loop(model) - torch.cuda.synchronize() if torch.cuda.is_available() else None - hessian_time = time.time() - hessian_start + for handle in gptq_handles.values(): + handle.cleanup() - # Phase 3: Update weights using computed Hessians (same as gptq_lite) - weight_update_start = time.time() print_rank_0("Updating weights using GPTQ algorithm...") - for name, module in layer.named_modules(): - if is_quantized_linear(module) and module.weight_quantizer.is_enabled: - state = 
hessian_state[module.name] - hessian = state["hessian"].to(module.weight.device) - blockwise_weight_update( - module, hessian, block_size, percdamp, n_samples=state["n_samples"] - ) - del hessian_state[module.name] + for handle in gptq_handles.values(): + handle.quantize(block_size, percdamp) + del gptq_handles + if torch.cuda.is_available(): torch.cuda.empty_cache() - - torch.cuda.synchronize() if torch.cuda.is_available() else None - weight_update_time = time.time() - weight_update_start - - total_time = time.time() - total_start - print_rank_0( - f"GPTQ timing - Hessian: {hessian_time:.2f}s, " - f"Weight update: {weight_update_time:.2f}s, " - f"Total: {total_time:.2f}s" - ) + print_rank_0(f"GPTQ time: {time.time() - total_start:.2f}s") diff --git a/modelopt/torch/quantization/triton/gptq_fused_kernel.py b/modelopt/torch/quantization/triton/gptq_fused_kernel.py deleted file mode 100644 index 21d84713a1..0000000000 --- a/modelopt/torch/quantization/triton/gptq_fused_kernel.py +++ /dev/null @@ -1,189 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Fused Triton kernel for the GPTQ blockwise weight-update inner loop. - -The standard GPTQ inner loop launches ~10-15 CUDA kernels per column -(amax lookup, FP4 quantization, error computation, rank-1 update). 
-For ``block_size=128`` that is ~1 500 kernel launches per block, each with -~5-10 us of launch overhead dominating actual compute. - -This module fuses the entire inner loop into a **single** Triton kernel per -block. Rows are independent and map to Triton programs; columns are processed -sequentially inside each program so the rank-1 error update is carried forward -without synchronisation. - -Supported quantisation format: **NVFP4 static block quantisation** (two-level -scaling with per-group amax and a global amax). -""" - -import torch -import triton -import triton.language as tl - -__all__ = ["gptq_fused_block"] - -# -- NVFP4 constants used by the kernel ------------------------------------ -# Maximum representable FP4-E2M1 value (1 + 1 + 0.5 = 6.0 when decoded via -# the standard E2M1 table: {0, 0.5, 1, 1.5, 2, 3, 4, 6}). -_FP4_MAX = 6.0 -# FP8-E4M3 has max representable value 448. -_FP8_E4M3_MAX = 448.0 - - -@triton.jit -def _gptq_fused_block_kernel( - w_ptr, # [num_rows, BLOCK_SIZE] working weight block (in-place) - qw_ptr, # [num_rows, BLOCK_SIZE] output: quantized weights - err_ptr, # [num_rows, BLOCK_SIZE] output: quantization errors - amax_ptr, # [num_rows, num_groups] per-group amax, row-major - global_amax_ptr, # scalar float32 on device - hinv_ptr, # [BLOCK_SIZE, BLOCK_SIZE] upper Cholesky of H^{-1} - num_rows, - num_groups, - group_size: tl.constexpr, - block_start, # column offset of this block in the full weight matrix - n_cols, # actual columns in this block (may be < BLOCK_SIZE) - BLOCK_SIZE: tl.constexpr, -): - """One program per row; sequentially quantizes columns, propagating errors.""" - row = tl.program_id(0) - if row >= num_rows: - return - - # Base pointers for this row - w_base = w_ptr + row * BLOCK_SIZE - qw_base = qw_ptr + row * BLOCK_SIZE - err_base = err_ptr + row * BLOCK_SIZE - amax_row_base = amax_ptr + row * num_groups - - # Pre-compute global FP8 scale factors (constant across columns) - global_amax = 
tl.load(global_amax_ptr).to(tl.float32) - global_scale = global_amax / 6.0 # _FP4_MAX - fp8_inv_scale = tl.where(global_scale > 0.0, 1.0 / (448.0 / global_scale), 0.0) - - j_range = tl.arange(0, BLOCK_SIZE) - - for i in range(BLOCK_SIZE): - wi = tl.load(w_base + i) - - # -- Compute NVFP4 two-level scale for this column's group ----------- - col_idx = block_start + i - group_idx = col_idx // group_size - raw_amax = tl.load(amax_row_base + group_idx).to(tl.float32) - raw_scale = raw_amax / 6.0 # _FP4_MAX - - # FP8-quantize the block scale: scale * fp8_scale -> cast E4M3 -> back - fp8_scale = tl.where(global_scale > 0.0, 448.0 / global_scale, 1.0) - si = (raw_scale * fp8_scale).to(tl.float8e4nv).to(tl.float32) * fp8_inv_scale - - # Guard: replace zero / nan / inf scale with 1.0 - # NOTE: ``si != si`` is the standard NaN check in Triton (no math.isnan). - si_safe = tl.where( - (si == 0.0) | (si != si) | (tl.abs(si) == float("inf")), # noqa: PLR0124 - 1.0, - si, - ) - - # -- FP4-E2M1 fake quantization (nearest-round to 8 levels) ---------- - abs_scaled = tl.abs(wi) / si_safe - q_val = tl.where( - abs_scaled <= 0.25, - 0.0, - tl.where( - abs_scaled < 0.75, - 0.5, - tl.where( - abs_scaled <= 1.25, - 1.0, - tl.where( - abs_scaled < 1.75, - 1.5, - tl.where( - abs_scaled <= 2.5, - 2.0, - tl.where(abs_scaled < 3.5, 3.0, tl.where(abs_scaled <= 5.0, 4.0, 6.0)), - ), - ), - ), - ), - ) - - qi = q_val * si_safe * tl.where(wi >= 0.0, 1.0, -1.0) - tl.store(qw_base + i, qi) - - # -- GPTQ error and rank-1 update ------------------------------------ - di = tl.load(hinv_ptr + i * BLOCK_SIZE + i) - err_i = (wi - qi) / di - tl.store(err_base + i, err_i) - - j_mask = (j_range > i) & (j_range < n_cols) - hinv_row = tl.load(hinv_ptr + i * BLOCK_SIZE + j_range, mask=j_mask, other=0.0) - w_rem = tl.load(w_base + j_range, mask=j_mask, other=0.0) - w_rem = w_rem - err_i * hinv_row - tl.store(w_base + j_range, w_rem, mask=j_mask) - - -def gptq_fused_block( - w_block: torch.Tensor, - 
amax_grouped: torch.Tensor, - global_amax: torch.Tensor, - h_inv_cho_blk: torch.Tensor, - group_size: int, - block_start: int, - n_cols: int, -) -> tuple[torch.Tensor, torch.Tensor]: - """Run the GPTQ column loop for one block in a single Triton kernel launch. - - Args: - w_block: Working weight block of shape ``[num_rows, block_size]`` (will be cloned). - amax_grouped: Per-group amax of shape ``[num_rows, num_groups]``. - global_amax: Scalar tensor with the global amax. - h_inv_cho_blk: Upper Cholesky factor of H^{-1}, shape ``[block_size, block_size]``. - group_size: NVFP4 quantization group size (typically 16). - block_start: Column offset of this block in the full weight matrix. - n_cols: Actual number of columns in this block (``<= block_size``). - - Returns: - Tuple of ``(qw_block, err_block)`` each of shape ``[num_rows, block_size]``. - """ - num_rows, block_size = w_block.shape - num_groups = amax_grouped.shape[1] - - w_block = w_block.contiguous() - amax_grouped = amax_grouped.contiguous() - h_inv_cho_blk = h_inv_cho_blk.contiguous() - - qw_block = torch.empty_like(w_block) - err_block = torch.empty_like(w_block) - - grid = (num_rows,) - with torch.cuda.device(w_block.device): - _gptq_fused_block_kernel[grid]( - w_block, - qw_block, - err_block, - amax_grouped, - global_amax, - h_inv_cho_blk, - num_rows, - num_groups, - group_size, - block_start, - n_cols, - BLOCK_SIZE=block_size, - ) - - return qw_block, err_block diff --git a/tests/gpu/torch/quantization/test_gptq.py b/tests/gpu/torch/quantization/test_gptq.py index 7f8db20446..f42cde6687 100644 --- a/tests/gpu/torch/quantization/test_gptq.py +++ b/tests/gpu/torch/quantization/test_gptq.py @@ -21,14 +21,7 @@ import modelopt.torch.quantization as mtq from modelopt.torch.export.unified_export_hf import _export_quantized_weight -from modelopt.torch.quantization.model_calib import ( - _blockwise_weight_update_fused, - _blockwise_weight_update_unfused, - blockwise_weight_update, - prepare_hessian_inverse, - 
update_hessian, -) -from modelopt.torch.quantization.nn import NVFP4StaticQuantizer +from modelopt.torch.quantization.model_calib import GPTQHandle, update_hessian from modelopt.torch.quantization.qtensor.nvfp4_tensor import NVFP4QTensor from modelopt.torch.utils.dataset_utils import create_forward_loop, get_dataset_dataloader @@ -163,8 +156,10 @@ def test_gptq_updates(block_size, dim, model_weight, expect_weight_change): f"n_samples should be {expected_tokens_2}, got {n_samples}" ) - hessian = hessian.to(input.device) - blockwise_weight_update(model, hessian, block_size, 0.1) + handle = GPTQHandle(model, "linear") + handle.hessian = hessian.to(input.device) + handle.n_samples = n_samples + handle.quantize(block_size, 0.1) if expect_weight_change: # Weight must change as GPTQ updates weights to adjust for quantization error assert not torch.allclose(model.weight.data, q_dq_weight), "Weight should not be equal" @@ -196,7 +191,10 @@ def test_gptq_export_roundtrip(): hessian, n_samples = update_hessian(input_tensor, hessian, n_samples) hessian = hessian.to("cuda") - blockwise_weight_update(model, hessian, block_size, percdamp=0.1) + handle = GPTQHandle(model, "linear") + handle.hessian = hessian + handle.n_samples = n_samples + handle.quantize(block_size, percdamp=0.1) # Save the QDQ reference from the quantizer applied to GPTQ'd weights gptq_weight_shape = model.weight.data.shape @@ -309,87 +307,3 @@ def test_gptq_e2e_flow(quant_cfg): print( f"Generated ids after quantization: {tokenizer.decode(generated_ids_after_ptq[0], skip_special_tokens=True)}" ) - - -@pytest.mark.parametrize("dim", [256, 512]) -def test_fused_vs_unfused_nvfp4(dim): - """Verify that the fused Triton GPTQ kernel produces equivalent results to the unfused path. - - The fused kernel computes NVFP4 quantisation inline using Triton intrinsics, - which can differ slightly from the PyTorch-level quantiser path (different FP - rounding order). 
On real models (dim >= 4096) the relative MSE difference is - typically < 0.1%; at the smaller dims used here the tolerance is set to 20%. - """ - from modelopt.torch.quantization.model_calib import _promote_nvfp4_static_quantizers - - torch.manual_seed(RAND_SEED) - block_size = min(128, dim) - - # NVFP4_WEIGHT_ONLY_GPTQ_CFG uses *static* blocks, which get promoted to - # NVFP4StaticQuantizer — the prerequisite for the fused Triton path. - quant_cfg = copy.deepcopy(mtq.NVFP4_WEIGHT_ONLY_GPTQ_CFG) - quant_cfg["algorithm"] = "max" # calibrate only, don't run GPTQ - - model = torch.nn.Linear(dim, dim, bias=False).to("cuda") - model.name = "test_fused" - original_weight = model.weight.data.clone() - inp = torch.randn(4, 32, dim, device="cuda") - - mtq.quantize(model, quant_cfg, forward_loop=lambda m: m(inp)) - - # Promote to NVFP4StaticQuantizer (normally done by gptq / sequential_calibrate) - n_promoted = _promote_nvfp4_static_quantizers(model) - assert n_promoted > 0, "Expected at least one quantizer to be promoted" - - quantizer = model.weight_quantizer - assert isinstance(quantizer, NVFP4StaticQuantizer), ( - f"Expected NVFP4StaticQuantizer, got {type(quantizer).__name__}" - ) - - # Restore original weight and compute Hessian - model.weight.data = original_weight.clone() - hessian = torch.zeros(dim, dim, dtype=torch.float32) - n_samples = 0 - hessian, n_samples = update_hessian(inp, hessian, n_samples) - hessian = hessian.to("cuda") - - # --- Run fused path --- - weight_fused = original_weight.float().clone() - num_rows, num_cols = weight_fused.shape - h_inv = prepare_hessian_inverse(hessian, weight_fused, percdamp=0.01) - _blockwise_weight_update_fused(weight_fused, h_inv, quantizer, num_rows, num_cols, block_size) - - # --- Run unfused path --- - weight_unfused = original_weight.float().clone() - h_inv_unfused = prepare_hessian_inverse(hessian, weight_unfused, percdamp=0.01) - _blockwise_weight_update_unfused( - weight_unfused, h_inv_unfused, quantizer, num_cols, 
block_size, None, False - ) - - # Both paths must produce non-trivial updates - assert not torch.equal(weight_fused, original_weight.float()), ( - "Fused path did not update weights" - ) - assert not torch.equal(weight_unfused, original_weight.float()), ( - "Unfused path did not update weights" - ) - - # Compare Hessian-weighted relative MSE - def _relative_mse(q, w, h): - delta = q - w - return (delta.mm(h).mul(delta).mean() / (w.mm(h).mul(w).mean() + 1e-6)).item() - - orig_f = original_weight.float() - mse_fused = _relative_mse(weight_fused, orig_f, hessian) - mse_unfused = _relative_mse(weight_unfused, orig_f, hessian) - - assert mse_fused > 0, "Fused MSE should be positive" - assert mse_unfused > 0, "Unfused MSE should be positive" - - # At small test dimensions, inline Triton FP4 rounding can diverge up to ~15% - # from the PyTorch path. On production-scale layers this drops below 0.1%. - relative_mse_diff = abs(mse_fused - mse_unfused) / max(mse_fused, mse_unfused) - assert relative_mse_diff < 0.20, ( - f"Fused ({mse_fused:.6e}) and unfused ({mse_unfused:.6e}) MSE differ by " - f"{relative_mse_diff:.2%}, expected < 20%" - ) From 7e235b4da9f36501126022096f582e388c9fbc7a Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Mon, 23 Mar 2026 21:10:53 +0000 Subject: [PATCH 36/48] claude review + coderabbit review Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- modelopt/torch/quantization/config.py | 36 ---------------------- modelopt/torch/quantization/model_calib.py | 25 ++++++--------- modelopt/torch/quantization/model_quant.py | 2 +- tests/gpu/torch/quantization/test_gptq.py | 2 +- 4 files changed, 11 insertions(+), 54 deletions(-) diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index d096f8390e..c9ba9e7795 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -476,39 +476,6 @@ def 
_nvfp4_selective_quant_cfg( }, } -NVFP4_WEIGHT_ONLY_GPTQ_CFG = { - "quant_cfg": { - "*weight_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, - "axis": None, - "enable": True, - }, - "*input_quantizer": {"enable": False}, - **_default_disabled_quantizer_cfg, - }, - "algorithm": {"method": "gptq", "use_sequential": True}, -} - -NVFP4_GPTQ_CFG = { - "quant_cfg": { - "*weight_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, - "axis": None, - "enable": True, - }, - "*input_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, - "axis": None, - "enable": True, - }, - **_default_disabled_quantizer_cfg, - }, - "algorithm": {"method": "gptq", "use_sequential": True}, -} - MAMBA_MOE_NVFP4_AGGRESSIVE_CFG = { "quant_cfg": { "*weight_quantizer": _nvfp4_quantizer, @@ -665,9 +632,6 @@ def _nvfp4_selective_quant_cfg( "NVFP4_AWQ_FULL_CFG", "NVFP4_AWQ_LITE_CFG", "NVFP4_DEFAULT_CFG", - "NVFP4_GPTQ_CFG", - "NVFP4_WEIGHT_ONLY_CFG", - "NVFP4_WEIGHT_ONLY_GPTQ_CFG", "NVFP4_FP8_MHA_CONFIG", "NVFP4_KV_CFG", "NVFP4_KV_ROTATE_CFG", diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index c5b4bc2aab..7dcb42eb92 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -1573,7 +1573,7 @@ def cleanup(self): """Unpatch the module's forward method.""" unpatch_forward_method(self.module, self.CACHE_NAME) - def quantize(self, block_size, percdamp): + def update_weights(self, block_size, percdamp): """Run GPTQ blockwise weight update on this module. 
Populates ``self.weight`` and ``self.h_inv``, runs the blockwise update, @@ -1673,6 +1673,8 @@ def update_hessian(input, hessian, n_samples): n_samples: Number of samples already processed Returns: Tuple of (updated_hessian, new_sample_count) + + Note: input must be non-empty (batch_size > 0); a zero-sized input causes division by zero. """ # Flatten to 2D (total_tokens, features) first, so batch_size counts tokens input_flat = input.reshape(-1, input.shape[-1]).t().float() @@ -1689,19 +1691,6 @@ def update_hessian(input, hessian, n_samples): return hessian, n_samples -def _get_quantized_linear_layers(parent: nn.Module) -> list[tuple[str, nn.Module]]: - """Return (name, module) pairs for all quantized linear layers with enabled weight quantizers. - - Also sets ``module.name`` on each returned module for downstream logging. - """ - layers = [] - for name, module in parent.named_modules(): - if is_quantized_linear(module) and module.weight_quantizer.is_enabled: - module.name = name - layers.append((name, module)) - return layers - - @torch.no_grad() def sequential_calibrate( model: nn.Module, @@ -1816,7 +1805,11 @@ def gptq( max_calibrate(model, forward_loop=forward_loop) _promote_nvfp4_static_quantizers(model) - quantized_layers = _get_quantized_linear_layers(model) + quantized_layers = [ + (n, m) + for n, m in model.named_modules() + if is_quantized_linear(m) and m.weight_quantizer.is_enabled + ] if not quantized_layers: print_rank_0("No quantized linear layers found, skipping GPTQ") return @@ -1833,7 +1826,7 @@ def gptq( print_rank_0("Updating weights using GPTQ algorithm...") for handle in gptq_handles.values(): - handle.quantize(block_size, percdamp) + handle.update_weights(block_size, percdamp) del gptq_handles if torch.cuda.is_available(): diff --git a/modelopt/torch/quantization/model_quant.py b/modelopt/torch/quantization/model_quant.py index 4aa1ff46b4..d99c0e0fef 100644 --- a/modelopt/torch/quantization/model_quant.py +++ 
b/modelopt/torch/quantization/model_quant.py @@ -560,7 +560,7 @@ def get_auto_quantize_config(search_state, constraints=None, verbose=False): config = mtq.get_auto_quantize_config(search_state) # [Optional] Customize algorithm if needed - config["algorithm"] = {"method": "gptq_lite", "sequential": True} + config["algorithm"] = {"method": "gptq", "sequential": True} # Reuse on the same model (e.g. run a longer calibration pass) model = mtq.quantize(model, config, forward_loop=calibrate_loop) diff --git a/tests/gpu/torch/quantization/test_gptq.py b/tests/gpu/torch/quantization/test_gptq.py index f42cde6687..7d841ca141 100644 --- a/tests/gpu/torch/quantization/test_gptq.py +++ b/tests/gpu/torch/quantization/test_gptq.py @@ -278,7 +278,7 @@ def test_gptq_e2e_flow(quant_cfg): model.eval() quant_cfg = copy.deepcopy(quant_cfg) - quant_cfg["algorithm"] = "gptq_lite" + quant_cfg["algorithm"] = {"method": "gptq", "use_sequential": True} # Define quantizer/dataloader calib_dataloader = get_dataset_dataloader( dataset_name="cnn_dailymail", From d1498be342043c2b9d71724dba3470aa6e767a4d Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Mon, 23 Mar 2026 21:24:09 +0000 Subject: [PATCH 37/48] refactor Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- modelopt/torch/quantization/model_calib.py | 296 +++++++++++---------- modelopt/torch/quantization/model_quant.py | 3 - tests/gpu/torch/quantization/test_gptq.py | 42 +-- 3 files changed, 157 insertions(+), 184 deletions(-) diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 7dcb42eb92..2dc4485465 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -1519,151 +1519,6 @@ def postprocess(module, name): max_calibrate(model, forward_loop) -class GPTQHandle: - """Encapsulates per-module GPTQ state and operations. 
- - Owns the Hessian, patches the forward during collection, and contains - the blockwise weight-update logic. - - Instance attributes set during ``__init__``: - module, name, hessian, n_samples - - Instance attributes set during ``quantize``: - weight: float working copy of module weights (mutated in-place by update methods) - h_inv: upper-triangular Cholesky factor of the damped inverse Hessian - """ - - CACHE_NAME = "_forward_no_gptq_hessian" - - def __init__(self, module, name, offload_to_cpu=False): - self.module = module - self.name = name - in_features = module.weight.shape[-1] - device = module.weight.device - if offload_to_cpu and get_used_gpu_mem_fraction(device) > 0.65: - device = "cpu" - self.hessian = torch.zeros(in_features, in_features, dtype=torch.float32, device=device) - self.n_samples = 0 - # Set by quantize(); listed here for documentation. - self.weight: torch.Tensor | None = None - self.h_inv: torch.Tensor | None = None - - def setup(self): - """Patch the module's forward to accumulate Hessian during the collection pass.""" - gptq_handle = self - - def hessian_forward(self, input, *args, **kwargs): - inp = input.to_local() if hasattr(input, "to_local") else input - if self.input_quantizer is not None and self.input_quantizer.is_enabled: - hessian_input = self.input_quantizer(inp) - else: - hessian_input = inp - gptq_handle.hessian, gptq_handle.n_samples = update_hessian( - hessian_input, gptq_handle.hessian, gptq_handle.n_samples - ) - - self.weight_quantizer.disable() - out = self._forward_no_gptq_hessian(input, *args, **kwargs) - self.weight_quantizer.enable() - return out - - bind_forward_method(self.module, hessian_forward, self.CACHE_NAME) - - def cleanup(self): - """Unpatch the module's forward method.""" - unpatch_forward_method(self.module, self.CACHE_NAME) - - def update_weights(self, block_size, percdamp): - """Run GPTQ blockwise weight update on this module. 
- - Populates ``self.weight`` and ``self.h_inv``, runs the blockwise update, - logs MSE, and writes the result back to the module. - """ - hessian = self.hessian.to(self.module.weight.device) - self.weight = self.module.weight.data.float().clone() - self._prepare_hessian_inverse(hessian, percdamp) - - self._blockwise_update(block_size) - - self._print_mse_error(hessian) - self.module.weight.data = self.weight.reshape(self.module.weight.shape).to( - self.module.weight.data.dtype - ) - - # ------------------------------------------------------------------ - # Quantize helpers — all read from self.module, self.weight, self.h_inv - # ------------------------------------------------------------------ - - def _prepare_hessian_inverse(self, hessian, percdamp): - """Compute damped inverse Hessian and store as ``self.h_inv``. - - Dead-neuron columns (all-zero in ``self.weight``) are zeroed in the - Hessian before inversion, matching the FP-Quant reference: - https://github.com/IST-DASLab/FP-Quant/blob/d2e3092f968262c4de5fb050e1aef568a280dadd/src/quantization/gptq.py#L200 - """ - assert self.weight is not None, "_prepare_hessian_inverse called before quantize()" - h = hessian.clone() - zero_cols = torch.nonzero(self.weight.eq(0).all(dim=0)).unsqueeze(-1) - - h[zero_cols, :] = 0 - h[:, zero_cols] = 0 - h[zero_cols, zero_cols] = 1 - - damp = percdamp * torch.mean(torch.diag(h)) - diag_indices = torch.arange(h.shape[0], device=h.device) - h[diag_indices, diag_indices] += damp - - try: - h = torch.cholesky_inverse(torch.linalg.cholesky(h)) - self.h_inv = torch.linalg.cholesky(h, upper=True) - except (RuntimeError, torch.linalg.LinAlgError): - print_rank_0("Warning: Hessian is not positive definite, using identity matrix") - self.h_inv = torch.eye(h.shape[0], device=h.device, dtype=h.dtype) - - def _blockwise_update(self, block_size): - """Column-wise GPTQ update using full-matrix QDQ. 
- - For each column, quantizes the full weight matrix via the quantizer and - extracts the quantized column. This is the standard GPTQ approach. - - Reads/writes ``self.weight`` and ``self.h_inv`` in-place. - """ - assert self.weight is not None and self.h_inv is not None, ( - "_blockwise_update called before _prepare_hessian_inverse()" - ) - quantizer = self.module.weight_quantizer - num_cols = self.weight.shape[1] - - for block_start in range(0, num_cols, block_size): - block_end = min(block_start + block_size, num_cols) - n_cols_blk = block_end - block_start - h_inv_cho_blk = self.h_inv[block_start:block_end, block_start:block_end] - - wblk = self.weight.clone() - errs = torch.zeros_like(wblk[:, block_start:block_end]) - - for i in range(n_cols_blk): - w_ci = wblk[:, block_start + i] - d = h_inv_cho_blk[i, i] - qdq = quantizer(wblk) - self.weight[:, block_start + i] = qdq[:, block_start + i] - err = (w_ci - qdq[:, block_start + i]) / d - wblk[:, block_start + i : block_end].addr_(err, h_inv_cho_blk[i, i:], alpha=-1) - errs[:, i] = err - - self.weight[:, block_end:].addmm_( - errs, self.h_inv[block_start:block_end, block_end:], alpha=-1 - ) - - def _print_mse_error(self, hessian): - """Log Hessian-weighted relative MSE between ``self.weight`` and original weights.""" - w_orig = self.module.weight.float() - delta = self.weight - w_orig - mse = (delta).mm(hessian).mul(delta).mean() / (w_orig.mm(hessian).mul(w_orig).mean() + 1e-6) - suffix = f", n_hessian_samples: {self.n_samples}" if self.n_samples else "" - print_rank_0(f"[{self.name}] Relative MSE error: {mse.item():.2e}{suffix}") - - def update_hessian(input, hessian, n_samples): """Update hessian matrix with new input samples using incremental formula. @@ -1800,6 +1655,155 @@ def gptq( percdamp: Percentage of avg Hessian diagonal for damping (default: 0.01). block_size: Block size for GPTQ weight update. """ + + class GPTQHelper: + """Encapsulates per-module GPTQ state and operations. 
+ + Owns the Hessian, patches the forward during collection, and contains + the blockwise weight-update logic. + + Instance attributes set during ``__init__``: + module, name, hessian, n_samples + + Instance attributes set during ``update_weights``: + weight: float working copy of module weights (mutated in-place by update methods) + h_inv: upper-triangular Cholesky factor of the damped inverse Hessian + """ + + CACHE_NAME = "_forward_no_gptq_hessian" + + def __init__(self, module, name, offload_to_cpu=False): + self.module = module + self.name = name + in_features = module.weight.shape[-1] + device = module.weight.device + if offload_to_cpu and get_used_gpu_mem_fraction(device) > 0.65: + device = "cpu" + self.hessian = torch.zeros(in_features, in_features, dtype=torch.float32, device=device) + self.n_samples = 0 + # Set by update_weights(); listed here for documentation. + self.weight: torch.Tensor | None = None + self.h_inv: torch.Tensor | None = None + + def setup(self): + """Patch the module's forward to accumulate Hessian during the collection pass.""" + gptq_helper = self + + def hessian_forward(self, input, *args, **kwargs): + inp = input.to_local() if hasattr(input, "to_local") else input + if self.input_quantizer is not None and self.input_quantizer.is_enabled: + hessian_input = self.input_quantizer(inp) + else: + hessian_input = inp + gptq_helper.hessian, gptq_helper.n_samples = update_hessian( + hessian_input, gptq_helper.hessian, gptq_helper.n_samples + ) + + self.weight_quantizer.disable() + out = self._forward_no_gptq_hessian(input, *args, **kwargs) + self.weight_quantizer.enable() + return out + + bind_forward_method(self.module, hessian_forward, self.CACHE_NAME) + + def cleanup(self): + """Unpatch the module's forward method.""" + unpatch_forward_method(self.module, self.CACHE_NAME) + + def update_weights(self, block_size, percdamp): + """Run GPTQ blockwise weight update on this module. 
+ + Populates ``self.weight`` and ``self.h_inv``, runs the blockwise update, + logs MSE, and writes the result back to the module. + """ + hessian = self.hessian.to(self.module.weight.device) + self.weight = self.module.weight.data.float().clone() + self._prepare_hessian_inverse(hessian, percdamp) + + self._blockwise_update(block_size) + + self._print_mse_error(hessian) + self.module.weight.data = self.weight.reshape(self.module.weight.shape).to( + self.module.weight.data.dtype + ) + + # ------------------------------------------------------------------ + # Quantize helpers — all read from self.module, self.weight, self.h_inv + # ------------------------------------------------------------------ + + def _prepare_hessian_inverse(self, hessian, percdamp): + """Compute damped inverse Hessian and store as ``self.h_inv``. + + Dead-neuron columns (all-zero in ``self.weight``) are zeroed in the + Hessian before inversion, matching the FP-Quant reference: + https://github.com/IST-DASLab/FP-Quant/blob/d2e3092f968262c4de5fb050e1aef568a280dadd/src/quantization/gptq.py#L200 + """ + assert self.weight is not None, ( + "_prepare_hessian_inverse called before update_weights()" + ) + h = hessian.clone() + zero_cols = torch.nonzero(self.weight.eq(0).all(dim=0)).unsqueeze(-1) + + h[zero_cols, :] = 0 + h[:, zero_cols] = 0 + h[zero_cols, zero_cols] = 1 + + damp = percdamp * torch.mean(torch.diag(h)) + diag_indices = torch.arange(h.shape[0], device=h.device) + h[diag_indices, diag_indices] += damp + + try: + h = torch.cholesky_inverse(torch.linalg.cholesky(h)) + self.h_inv = torch.linalg.cholesky(h, upper=True) + except (RuntimeError, torch.linalg.LinAlgError): + print_rank_0("Warning: Hessian is not positive definite, using identity matrix") + self.h_inv = torch.eye(h.shape[0], device=h.device, dtype=h.dtype) + + def _blockwise_update(self, block_size): + """Column-wise GPTQ update using full-matrix QDQ. 
+ + For each column, quantizes the full weight matrix via the quantizer and + extracts the quantized column. This is the standard GPTQ approach. + + Reads/writes ``self.weight`` and ``self.h_inv`` in-place. + """ + assert self.weight is not None and self.h_inv is not None, ( + "_blockwise_update called before _prepare_hessian_inverse()" + ) + quantizer = self.module.weight_quantizer + num_cols = self.weight.shape[1] + + for block_start in range(0, num_cols, block_size): + block_end = min(block_start + block_size, num_cols) + n_cols_blk = block_end - block_start + h_inv_cho_blk = self.h_inv[block_start:block_end, block_start:block_end] + + wblk = self.weight.clone() + errs = torch.zeros_like(wblk[:, block_start:block_end]) + + for i in range(n_cols_blk): + w_ci = wblk[:, block_start + i] + d = h_inv_cho_blk[i, i] + qdq = quantizer(wblk) + self.weight[:, block_start + i] = qdq[:, block_start + i] + err = (w_ci - qdq[:, block_start + i]) / d + wblk[:, block_start + i : block_end].addr_(err, h_inv_cho_blk[i, i:], alpha=-1) + errs[:, i] = err + + self.weight[:, block_end:].addmm_( + errs, self.h_inv[block_start:block_end, block_end:], alpha=-1 + ) + + def _print_mse_error(self, hessian): + """Log Hessian-weighted relative MSE between ``self.weight`` and original weights.""" + w_orig = self.module.weight.float() + delta = self.weight - w_orig + mse = (delta).mm(hessian).mul(delta).mean() / ( + w_orig.mm(hessian).mul(w_orig).mean() + 1e-6 + ) + suffix = f", n_hessian_samples: {self.n_samples}" if self.n_samples else "" + print_rank_0(f"[{self.name}] Relative MSE error: {mse.item():.2e}{suffix}") + total_start = time.time() max_calibrate(model, forward_loop=forward_loop) @@ -1814,7 +1818,7 @@ def gptq( print_rank_0("No quantized linear layers found, skipping GPTQ") return - gptq_handles = {name: GPTQHandle(m, name, offload_to_cpu=True) for name, m in quantized_layers} + gptq_handles = {name: GPTQHelper(m, name, offload_to_cpu=True) for name, m in quantized_layers} for 
handle in gptq_handles.values(): handle.setup() diff --git a/modelopt/torch/quantization/model_quant.py b/modelopt/torch/quantization/model_quant.py index d99c0e0fef..7152efacfc 100644 --- a/modelopt/torch/quantization/model_quant.py +++ b/modelopt/torch/quantization/model_quant.py @@ -559,9 +559,6 @@ def get_auto_quantize_config(search_state, constraints=None, verbose=False): # Or use the original result config = mtq.get_auto_quantize_config(search_state) - # [Optional] Customize algorithm if needed - config["algorithm"] = {"method": "gptq", "sequential": True} - # Reuse on the same model (e.g. run a longer calibration pass) model = mtq.quantize(model, config, forward_loop=calibrate_loop) diff --git a/tests/gpu/torch/quantization/test_gptq.py b/tests/gpu/torch/quantization/test_gptq.py index 7d841ca141..ddfe0fc892 100644 --- a/tests/gpu/torch/quantization/test_gptq.py +++ b/tests/gpu/torch/quantization/test_gptq.py @@ -21,7 +21,7 @@ import modelopt.torch.quantization as mtq from modelopt.torch.export.unified_export_hf import _export_quantized_weight -from modelopt.torch.quantization.model_calib import GPTQHandle, update_hessian +from modelopt.torch.quantization.model_calib import gptq, update_hessian from modelopt.torch.quantization.qtensor.nvfp4_tensor import NVFP4QTensor from modelopt.torch.utils.dataset_utils import create_forward_loop, get_dataset_dataloader @@ -127,39 +127,20 @@ def test_update_hessian(): def test_gptq_updates(block_size, dim, model_weight, expect_weight_change): model = torch.nn.Linear(dim, dim).to("cuda") model.weight.data = model_weight - model.name = "linear" original_weight = model_weight.clone() - input = torch.randn(2, 16, dim).to("cuda") - hessian = torch.zeros(dim, dim).to("cpu") - n_samples = 0 + input_tensor = torch.randn(2, 16, dim).to("cuda") quant_cfg = mtq.NVFP4_DEFAULT_CFG - mtq.quantize(model, quant_cfg, forward_loop=lambda model: model(input)) + mtq.quantize(model, quant_cfg, forward_loop=lambda model: model(input_tensor)) # 
Get qdq weight q_dq_weight = model.weight_quantizer(model.weight.data) - # Restore original weight + # Restore original weight before GPTQ model.weight.data = original_weight.clone() - hessian, n_samples = update_hessian(input, hessian, n_samples) - - # Verify n_samples counts total tokens (batch * seq_len) after flattening - expected_tokens = input.shape[0] * input.shape[1] # 2 * 16 = 32 - assert n_samples == expected_tokens, f"n_samples should be {expected_tokens}, got {n_samples}" - - # Perform another forward pass to update hessian matrix - input_2 = torch.randn(3, 16, dim).to("cuda") - hessian, n_samples = update_hessian(input_2, hessian, n_samples) - expected_tokens_2 = expected_tokens + input_2.shape[0] * input_2.shape[1] # 32 + 48 = 80 - assert n_samples == expected_tokens_2, ( - f"n_samples should be {expected_tokens_2}, got {n_samples}" - ) - - handle = GPTQHandle(model, "linear") - handle.hessian = hessian.to(input.device) - handle.n_samples = n_samples - handle.quantize(block_size, 0.1) + # Run GPTQ through the public API + gptq(model, forward_loop=lambda m: m(input_tensor), percdamp=0.1, block_size=block_size) if expect_weight_change: # Weight must change as GPTQ updates weights to adjust for quantization error assert not torch.allclose(model.weight.data, q_dq_weight), "Weight should not be equal" @@ -175,7 +156,6 @@ def test_gptq_export_roundtrip(): # Step 1: Create a simple linear model and quantize to install NVFP4 quantizers model = torch.nn.Linear(dim, dim).to("cuda") - model.name = "linear" original_weight = model.weight.data.clone() input_tensor = torch.randn(2, 16, dim).to("cuda") quant_cfg = mtq.NVFP4_DEFAULT_CFG @@ -186,15 +166,7 @@ def test_gptq_export_roundtrip(): model.weight.data = original_weight.clone() # Step 2: Perform GPTQ — compute Hessian and update weights - hessian = torch.zeros(dim, dim, dtype=torch.float32) - n_samples = 0 - hessian, n_samples = update_hessian(input_tensor, hessian, n_samples) - hessian = hessian.to("cuda") - - 
handle = GPTQHandle(model, "linear") - handle.hessian = hessian - handle.n_samples = n_samples - handle.quantize(block_size, percdamp=0.1) + gptq(model, forward_loop=lambda m: m(input_tensor), percdamp=0.1, block_size=block_size) # Save the QDQ reference from the quantizer applied to GPTQ'd weights gptq_weight_shape = model.weight.data.shape From d8b1d930fbab947f33d16b9ee9726eef4332c211 Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Mon, 23 Mar 2026 21:27:55 +0000 Subject: [PATCH 38/48] stray changes removed Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- examples/llm_ptq/hf_ptq.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index bfe2c861c5..5620ddf6a4 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -107,7 +107,6 @@ def _set_kv_cache_constant_amax(quant_cfg: dict) -> None: "nvfp4_experts_only": mtq.NVFP4_EXPERTS_ONLY_CFG, "nvfp4_omlp_only": mtq.NVFP4_OMLP_ONLY_CFG, "nvfp4_svdquant": mtq.NVFP4_SVDQUANT_DEFAULT_CFG, - "nvfp4_gptq": mtq.NVFP4_GPTQ_CFG, "mxfp8": mtq.MXFP8_DEFAULT_CFG, "nvfp4_local_hessian": mtq.NVFP4_W4A4_WEIGHT_LOCAL_HESSIAN_CFG, } @@ -1003,7 +1002,6 @@ def quantize_main( is_nemotron_vl_model, first_text_speech_dataset, ) - export_quantized( args, full_model, @@ -1162,7 +1160,6 @@ def parse_args() -> argparse.Namespace: default=False, action="store_true", ) - parser.add_argument( "--low_memory_mode", help=( From 19fc0c220f05f0b04ab848a55a07ff9a50e2d07c Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Wed, 25 Mar 2026 17:50:43 +0000 Subject: [PATCH 39/48] Address PR comments Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- modelopt/torch/quantization/config.py | 4 +- modelopt/torch/quantization/model_calib.py | 53 ++++++++----------- .../torch/quantization/utils/core_utils.py | 43 +++++++++++++++ 
tests/gpu/torch/quantization/test_gptq.py | 20 +------ 4 files changed, 69 insertions(+), 51 deletions(-) diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index c9ba9e7795..be11108b28 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -1321,14 +1321,14 @@ class GPTQConfig(QuantizeAlgorithmConfig): """ method: Literal["gptq"] = ModeloptField("gptq") - percdamp: float | None = ModeloptField( + percdamp: float = ModeloptField( default=0.01, gt=0.0, le=1.0, title="Percentage damping factor.", description="The percentage of average Hessian diagonal used for damping.", ) - block_size: int | None = ModeloptField( + block_size: int = ModeloptField( default=128, title="Block size for GPTQ weight update.", description="""The block size for GPTQ weight update, which must be a multiple of the diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 2dc4485465..3c11058507 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -39,12 +39,14 @@ from .nn import NVFP4StaticQuantizer, QuantModule, SequentialQuantizer, TensorQuantizer from .utils import ( disable_calib, + disabled_weight_quantizers, enable_fake_quant, enable_quant, enable_weight_access_and_writeback, is_quantized_column_parallel_linear, is_quantized_linear, is_quantized_row_parallel_linear, + promote_nvfp4_static_quantizers, quantizer_attr_names, reduce_amax, weight_attr_names, @@ -1596,33 +1598,6 @@ def _layer_forward_loop(m, _inputs=layer_inputs): print_rank_0("Sequential calibration completed") -def _promote_nvfp4_static_quantizers(model: nn.Module) -> int: - """Convert eligible TensorQuantizers to NVFP4StaticQuantizer in-place. 
- - After max calibration sets per-block amax values, NVFP4 static quantizers - need to be promoted so they use the two-level scaling path (global amax + - per-block amax) instead of the generic E4M3 path. - - Returns the number of quantizers converted. - """ - converted = 0 - for _name, module in list(model.named_modules()): - if isinstance(module, TensorQuantizer) and not module._disabled: - if module._calibrator is not None and not module._dynamic and hasattr(module, "_amax"): - is_nvfp4_static = ( - module.is_static_block_quant - and module._num_bits == (2, 1) - and module._block_sizes is not None - and module._block_sizes.get("scale_bits") == (4, 3) - ) - if is_nvfp4_static: - initial_amax = module._amax.clone().detach() - global_amax = reduce_amax(initial_amax, axis=None) - NVFP4StaticQuantizer.from_tensor_quantizer(module, global_amax=global_amax) - converted += 1 - return converted - - @torch.no_grad() def gptq( model: nn.Module, @@ -1699,9 +1674,8 @@ def hessian_forward(self, input, *args, **kwargs): hessian_input, gptq_helper.hessian, gptq_helper.n_samples ) - self.weight_quantizer.disable() out = self._forward_no_gptq_hessian(input, *args, **kwargs) - self.weight_quantizer.enable() + return out bind_forward_method(self.module, hessian_forward, self.CACHE_NAME) @@ -1710,6 +1684,12 @@ def cleanup(self): """Unpatch the module's forward method.""" unpatch_forward_method(self.module, self.CACHE_NAME) + def free(self): + """Release Hessian and working tensors to reclaim memory.""" + self.hessian = None + self.weight = None + self.h_inv = None + def update_weights(self, block_size, percdamp): """Run GPTQ blockwise weight update on this module. 
@@ -1771,6 +1751,14 @@ def _blockwise_update(self, block_size): "_blockwise_update called before _prepare_hessian_inverse()" ) quantizer = self.module.weight_quantizer + block_sizes = getattr(quantizer, "block_sizes", None) + if block_sizes is not None: + group_size = block_sizes.get(-1) + if group_size is not None and block_size % group_size != 0: + raise ValueError( + f"GPTQ block_size ({block_size}) must be divisible by the quantizer" + f" group_size ({group_size})" + ) num_cols = self.weight.shape[1] for block_start in range(0, num_cols, block_size): @@ -1807,7 +1795,7 @@ def _print_mse_error(self, hessian): total_start = time.time() max_calibrate(model, forward_loop=forward_loop) - _promote_nvfp4_static_quantizers(model) + promote_nvfp4_static_quantizers(model) quantized_layers = [ (n, m) @@ -1823,7 +1811,9 @@ def _print_mse_error(self, hessian): handle.setup() print_rank_0(f"Computing Hessians for {len(gptq_handles)} linear layers...") - forward_loop(model) + + with disabled_weight_quantizers(model): + forward_loop(model) for handle in gptq_handles.values(): handle.cleanup() @@ -1831,6 +1821,7 @@ def _print_mse_error(self, hessian): print_rank_0("Updating weights using GPTQ algorithm...") for handle in gptq_handles.values(): handle.update_weights(block_size, percdamp) + handle.free() del gptq_handles if torch.cuda.is_available(): diff --git a/modelopt/torch/quantization/utils/core_utils.py b/modelopt/torch/quantization/utils/core_utils.py index 4340b8dc1f..c3ca5b661d 100644 --- a/modelopt/torch/quantization/utils/core_utils.py +++ b/modelopt/torch/quantization/utils/core_utils.py @@ -27,6 +27,7 @@ from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam from torch.distributed.tensor import Replicate +from modelopt.torch.quantization.nn import NVFP4StaticQuantizer, TensorQuantizer from modelopt.torch.utils import get_unwrapped_name, print_rank_0 if TYPE_CHECKING: @@ -136,6 +137,33 @@ def convert_quantization_axis_to_reduce_axis(input, axis): 
return reduce_axis +def promote_nvfp4_static_quantizers(model: nn.Module) -> int: + """Convert eligible TensorQuantizers to NVFP4StaticQuantizer in-place. + + After max calibration sets per-block amax values, NVFP4 static quantizers + need to be promoted so they use the two-level scaling path (global amax + + per-block amax) instead of the generic E4M3 path. + + Returns the number of quantizers converted. + """ + converted = 0 + for _name, module in list(model.named_modules()): + if isinstance(module, TensorQuantizer) and not module._disabled: + if module._calibrator is not None and not module._dynamic and hasattr(module, "_amax"): + is_nvfp4_static = ( + module.is_static_block_quant + and module._num_bits == (2, 1) + and module._block_sizes is not None + and module._block_sizes.get("scale_bits") == (4, 3) + ) + if is_nvfp4_static: + initial_amax = module._amax.clone().detach() + global_amax = reduce_amax(initial_amax, axis=None) + NVFP4StaticQuantizer.from_tensor_quantizer(module, global_amax=global_amax) + converted += 1 + return converted + + @torch.no_grad() def reduce_amax(input, axis=None, keepdims=True, squeeze_scalar=True): """Compute the absolute maximum value of a tensor. @@ -710,6 +738,21 @@ def disable_calib(quantizer): quantizer._if_calib = original_if_calib +@contextmanager +def disabled_weight_quantizers(model: nn.Module): + """Disable weight quantizers during hessian collection.""" + disabled_modules = [] + for module in model.modules(): + if is_quantized_linear(module) and module.weight_quantizer.is_enabled: + module.weight_quantizer.disable() + disabled_modules.append(module) + try: + yield + finally: + for module in disabled_modules: + module.weight_quantizer.enable() + + @contextmanager def fsdp2_aware_weight_update(root_model, modules_to_update, reshard=True): """Context manager to update the FSDPParam list if an update is made to a submodule of an FSDPModule. 
diff --git a/tests/gpu/torch/quantization/test_gptq.py b/tests/gpu/torch/quantization/test_gptq.py index ddfe0fc892..e229c89e5a 100644 --- a/tests/gpu/torch/quantization/test_gptq.py +++ b/tests/gpu/torch/quantization/test_gptq.py @@ -204,26 +204,10 @@ def test_gptq_export_roundtrip(): num_mismatched = (diff > 1e-3).sum().item() total_elements = diff.numel() - print("\n--- Diff Stats ---") - print(f" Max diff: {max_diff}") - print(f" Mean diff: {diff.mean().item()}") - print(f" Median diff: {diff.median().item()}") - print(f" Std diff: {diff.std().item()}") - print( - f" Mismatched (>1e-3): {num_mismatched}/{total_elements} " - f"({100 * num_mismatched / total_elements:.2f}%)" - ) - print( - f" Max diff at [{max_diff_row}, {max_diff_col}]: " - f"deq={deq_weight[max_diff_row, max_diff_col].item()}, " - f"qdq_ref={qdq_ref[max_diff_row, max_diff_col].item()}" - ) - assert torch.allclose(deq_weight, qdq_ref.to(torch.bfloat16), atol=1e-2), ( f"Dequantized weight does not match QDQ reference. 
" - f"Max diff: {max_diff} at [{max_diff_row}, {max_diff_col}] " - f"(deq={deq_weight[max_diff_row, max_diff_col].item()}, " - f"qdq_ref={qdq_ref[max_diff_row, max_diff_col].item()})" + f"Max diff: {max_diff} at [{max_diff_row}, {max_diff_col}], " + f"mismatched (>1e-3): {num_mismatched}/{total_elements}" ) From 068e8a990326fef130ea7df2142583127f3711a2 Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Wed, 25 Mar 2026 18:21:03 +0000 Subject: [PATCH 40/48] fixed circular import issue Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- modelopt/torch/quantization/model_calib.py | 30 +++++++++++++++++-- .../torch/quantization/utils/core_utils.py | 28 ----------------- 2 files changed, 28 insertions(+), 30 deletions(-) diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 3c11058507..817e3a0b21 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -46,7 +46,6 @@ is_quantized_column_parallel_linear, is_quantized_linear, is_quantized_row_parallel_linear, - promote_nvfp4_static_quantizers, quantizer_attr_names, reduce_amax, weight_attr_names, @@ -1598,6 +1597,33 @@ def _layer_forward_loop(m, _inputs=layer_inputs): print_rank_0("Sequential calibration completed") +def _promote_nvfp4_static_quantizers(model: nn.Module) -> int: + """Convert eligible TensorQuantizers to NVFP4StaticQuantizer in-place. + + After max calibration sets per-block amax values, NVFP4 static quantizers + need to be promoted so they use the two-level scaling path (global amax + + per-block amax) instead of the generic E4M3 path. + + Returns the number of quantizers converted. 
+ """ + converted = 0 + for _name, module in list(model.named_modules()): + if isinstance(module, TensorQuantizer) and not module._disabled: + if module._calibrator is not None and not module._dynamic and hasattr(module, "_amax"): + is_nvfp4_static = ( + module.is_static_block_quant + and module._num_bits == (2, 1) + and module._block_sizes is not None + and module._block_sizes.get("scale_bits") == (4, 3) + ) + if is_nvfp4_static: + initial_amax = module._amax.clone().detach() + global_amax = reduce_amax(initial_amax, axis=None) + NVFP4StaticQuantizer.from_tensor_quantizer(module, global_amax=global_amax) + converted += 1 + return converted + + @torch.no_grad() def gptq( model: nn.Module, @@ -1795,7 +1821,7 @@ def _print_mse_error(self, hessian): total_start = time.time() max_calibrate(model, forward_loop=forward_loop) - promote_nvfp4_static_quantizers(model) + _promote_nvfp4_static_quantizers(model) quantized_layers = [ (n, m) diff --git a/modelopt/torch/quantization/utils/core_utils.py b/modelopt/torch/quantization/utils/core_utils.py index c3ca5b661d..9893d4e51a 100644 --- a/modelopt/torch/quantization/utils/core_utils.py +++ b/modelopt/torch/quantization/utils/core_utils.py @@ -27,7 +27,6 @@ from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam from torch.distributed.tensor import Replicate -from modelopt.torch.quantization.nn import NVFP4StaticQuantizer, TensorQuantizer from modelopt.torch.utils import get_unwrapped_name, print_rank_0 if TYPE_CHECKING: @@ -137,33 +136,6 @@ def convert_quantization_axis_to_reduce_axis(input, axis): return reduce_axis -def promote_nvfp4_static_quantizers(model: nn.Module) -> int: - """Convert eligible TensorQuantizers to NVFP4StaticQuantizer in-place. - - After max calibration sets per-block amax values, NVFP4 static quantizers - need to be promoted so they use the two-level scaling path (global amax + - per-block amax) instead of the generic E4M3 path. - - Returns the number of quantizers converted. 
- """ - converted = 0 - for _name, module in list(model.named_modules()): - if isinstance(module, TensorQuantizer) and not module._disabled: - if module._calibrator is not None and not module._dynamic and hasattr(module, "_amax"): - is_nvfp4_static = ( - module.is_static_block_quant - and module._num_bits == (2, 1) - and module._block_sizes is not None - and module._block_sizes.get("scale_bits") == (4, 3) - ) - if is_nvfp4_static: - initial_amax = module._amax.clone().detach() - global_amax = reduce_amax(initial_amax, axis=None) - NVFP4StaticQuantizer.from_tensor_quantizer(module, global_amax=global_amax) - converted += 1 - return converted - - @torch.no_grad() def reduce_amax(input, axis=None, keepdims=True, squeeze_scalar=True): """Compute the absolute maximum value of a tensor. From 2930b554f8e0e48a5ff30c131b8b09b77313dc3a Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Tue, 31 Mar 2026 21:23:07 +0000 Subject: [PATCH 41/48] tested e2e on qwen3-8b Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- examples/llm_ptq/hf_ptq.py | 196 +++++++++++++++++++-- modelopt/torch/quantization/model_calib.py | 156 +++++++++++++++- 2 files changed, 334 insertions(+), 18 deletions(-) diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index 5620ddf6a4..e017390481 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -15,6 +15,7 @@ import argparse import copy +import os import random import time import warnings @@ -61,6 +62,11 @@ ) from modelopt.torch.export.model_utils import get_language_model_from_vl, is_multimodal_model from modelopt.torch.quantization.config import _default_disabled_quantizer_cfg, need_calibration +from modelopt.torch.quantization.metrics_backup import ( + ActivationMSELogger, + compute_perplexity, + get_wikitext2, +) from modelopt.torch.quantization.plugins.accelerate import init_quantized_weights from modelopt.torch.quantization.utils import 
is_quantized from modelopt.torch.utils.dataset_utils import ( @@ -747,7 +753,7 @@ def pre_quantize( allow_fallback=False, ) else: - generated_ids_before_ptq = full_model.generate(preview_input_ids, max_new_tokens=100) + generated_ids_before_ptq = full_model.generate(preview_input_ids, max_new_tokens=2) return preview_input_ids, generated_ids_before_ptq @@ -786,7 +792,7 @@ def post_quantize( pass elif model_type != "llama4" and not is_nemotron_vl_model: # Our fake quantizer may not be fully compatible with torch.compile. - generated_ids_after_ptq = full_model.generate(preview_input_ids, max_new_tokens=100) + generated_ids_after_ptq = full_model.generate(preview_input_ids, max_new_tokens=2) elif is_nemotron_vl_model and tokenizer is not None: generated_ids_after_ptq = run_nemotron_vl_preview( full_model, @@ -910,6 +916,9 @@ def quantize_main( args, full_model, model_type, tokenizer, calib_dataloader, is_nemotron_vl_model ) + mse_logger = None + mse_data = None + if args.auto_quantize_bits: assert len(args.qformat.split(",")) > 1, ( "Auto quantization needs multiple quantization format." @@ -937,10 +946,17 @@ def quantize_main( "Plain quantization supports only one quantization format." ) - assert args.qformat in QUANT_CFG_CHOICES, ( - f"Unsupported quantization format: {args.qformat}, choices are: {list(QUANT_CFG_CHOICES.keys())}" - ) - quant_cfg = QUANT_CFG_CHOICES[args.qformat] + if args.qformat in QUANT_CFG_CHOICES: + quant_cfg = QUANT_CFG_CHOICES[args.qformat] + else: + # Fallback: resolve dynamically registered configs from the mtq namespace + # (e.g., PSX LUTS configs registered by modelopt-internal plugins). + quant_cfg = getattr(mtq, args.qformat, None) + assert quant_cfg is not None, ( + f"Unsupported quantization format: {args.qformat}, " + f"not found in built-in choices {list(QUANT_CFG_CHOICES.keys())} " + f"or in the mtq namespace (check that the required plugin is installed)." 
+ ) quant_cfg = build_quant_cfg( args.qformat, @@ -976,7 +992,64 @@ def quantize_main( quant_cfg = copy.deepcopy(quant_cfg) _set_kv_cache_constant_amax(quant_cfg["quant_cfg"]) - if args.qformat in QUANT_CFG_CHOICES: + # Collect original (unquantized) activations before quantization modifies the model + if getattr(args, "measure_activation_mse", False): + n_mse = getattr(args, "activation_mse_max_samples", 16) + mse_save_dir = getattr(args, "activation_mse_save_dir", None) + mse_input_path = getattr(args, "activation_mse_input_path", None) + + # Resolve MSE input data: frozen file (raw text or tokenized) or live dataloader + mse_data = None + if mse_input_path is not None: + if mse_input_path.endswith(".json"): + if os.path.isfile(mse_input_path): + print(f"Loading MSE input data from existing .json file: {mse_input_path}") + texts = ActivationMSELogger.load_raw_text(mse_input_path) + mse_data = ActivationMSELogger.tokenize_raw_text( + texts, + tokenizer, + max_length=args.calib_seq, + ) + else: + assert tokenizer is not None, ( + "--activation_mse_input_path with .json requires a tokenizer to decode" + ) + print(f"Creating MSE input data .json file: {mse_input_path}") + texts = ActivationMSELogger.materialize_raw_text( + calib_dataloader, + mse_input_path, + tokenizer=tokenizer, + max_samples=n_mse, + ) + mse_data = ActivationMSELogger.tokenize_raw_text( + texts, + tokenizer, + max_length=args.calib_seq, + ) + elif mse_input_path.endswith(".pt"): + if os.path.isfile(mse_input_path): + print(f"Loading MSE input data from existing .pt file: {mse_input_path}") + mse_data = ActivationMSELogger.load_data(mse_input_path) + else: + print(f"Creating MSE input data .pt file: {mse_input_path}") + mse_data = ActivationMSELogger.materialize_data( + calib_dataloader, + mse_input_path, + max_samples=n_mse, + ) + else: + raise ValueError( + f"--activation_mse_input_path must end with .json or .pt, got: {mse_input_path}" + ) + + if mse_data is None: + mse_data = calib_dataloader + 
+ mse_logger = ActivationMSELogger(max_samples=n_mse, save_dir=mse_save_dir) + print(f"Collecting original (unquantized) activations for MSE over {n_mse} samples...") + mse_logger.collect(language_model, mse_data, phase="original") + + if args.qformat in QUANT_CFG_CHOICES or hasattr(mtq, args.qformat): mono_quantize( args, quant_cfg, @@ -1002,15 +1075,54 @@ def quantize_main( is_nemotron_vl_model, first_text_speech_dataset, ) - export_quantized( - args, - full_model, - language_model, - model_type, - tokenizer, - default_padding_side, - default_pad_token, - ) + + if mse_logger is not None: + import gc + + print("Collecting quantized activations for MSE...") + mse_logger.collect(language_model, mse_data, phase="quantized") + + mse_logger.compute_mse() + print(mse_logger.summary()) + + if getattr(args, "activation_mse_save_dir", None): + mse_logger.save() + + del mse_logger, mse_data + gc.collect() + torch.cuda.empty_cache() + + if getattr(args, "eval_perplexity", False) and tokenizer is not None: + if getattr(args, "fold_weights", False): + print("Folding weights before perplexity evaluation...") + mtq.fold_weight(language_model) + seq_len = getattr(args, "eval_perplexity_seq_len", 2048) + eval_data = get_wikitext2(tokenizer, seq_len) + ppl = compute_perplexity(full_model, eval_data) + print(f"Wikitext-2 perplexity: {round(ppl, 2):.2f}") + + # Plugin-registered configs (e.g. PSX LUTS from modelopt-internal) are not exportable + # via the standard TRT-LLM / HF export paths. Fall back to save_pretrained(). + if args.qformat not in QUANT_CFG_CHOICES and hasattr(mtq, args.qformat): + print( + f"qformat '{args.qformat}' is a plugin-registered config and is not exportable " + f"via the standard export pipeline. Saving with save_pretrained() instead." 
+ ) + export_path = args.export_path + full_model.save_pretrained(export_path) + if tokenizer is not None: + tokenizer.save_pretrained(export_path) + print(f"Quantized model saved to: {export_path}") + else: + export_quantized( + args, + full_model, + language_model, + model_type, + tokenizer, + default_padding_side, + default_pad_token, + ) def parse_args() -> argparse.Namespace: @@ -1219,6 +1331,58 @@ def parse_args() -> argparse.Namespace: ), ) + parser.add_argument( + "--eval_perplexity", + action=argparse.BooleanOptionalAction, + default=False, + help="Evaluate Wikitext-2 perplexity after quantization (before export).", + ) + parser.add_argument( + "--eval_perplexity_seq_len", + type=int, + default=2048, + help="Sequence length for perplexity evaluation (default: 2048).", + ) + parser.add_argument( + "--measure_activation_mse", + action=argparse.BooleanOptionalAction, + default=False, + help="Measure per-layer activation MSE (original vs quantized) after quantization.", + ) + parser.add_argument( + "--activation_mse_max_samples", + type=int, + default=16, + help="Max calibration samples for activation MSE (default: 16).", + ) + parser.add_argument( + "--activation_mse_save_dir", + type=str, + default=None, + help="Directory to save activation MSE results. If not set, results are only printed.", + ) + parser.add_argument( + "--activation_mse_input_path", + type=str, + default=None, + help=( + "Path to frozen MSE input data. Supports two formats:\n" + " .json — raw text (cross-model reuse): if file exists, loads and re-tokenizes " + "with the current model's tokenizer; if not, decodes calibration data to text and saves.\n" + " .pt — tokenized tensors (same-tokenizer reuse): if file exists, loads directly; " + "if not, materializes from calibration data and saves." + ), + ) + parser.add_argument( + "--fold_weights", + action=argparse.BooleanOptionalAction, + default=False, + help=( + "Fold quantized weights before collecting activation MSE. 
" + "Speeds up the quantized forward pass by replacing weights in-place " + "and disabling fake-quant, but permanently mutates the weights." + ), + ) args = parser.parse_args() if args.moe_calib_experts_ratio is not None and not (0.0 < args.moe_calib_experts_ratio <= 1.0): parser.error("--moe_calib_experts_ratio must be in the range (0.0, 1.0].") diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 817e3a0b21..e9ddaf6fd0 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -1581,6 +1581,8 @@ def sequential_calibrate( try: for layer_idx, layer in enumerate(transformer_layers): print_rank_0(f"Calibrating layer {layer_idx + 1}/{len(transformer_layers)}") + # Store layer_idx so gptq/GPTQHelper can access it for debugging + layer._seq_calib_layer_idx = layer_idx layer_inputs = input_getter.get_input_activations(layer, forward_loop) def _layer_forward_loop(m, _inputs=layer_inputs): @@ -1762,8 +1764,21 @@ def _prepare_hessian_inverse(self, hessian, percdamp): h = torch.cholesky_inverse(torch.linalg.cholesky(h)) self.h_inv = torch.linalg.cholesky(h, upper=True) except (RuntimeError, torch.linalg.LinAlgError): - print_rank_0("Warning: Hessian is not positive definite, using identity matrix") - self.h_inv = torch.eye(h.shape[0], device=h.device, dtype=h.dtype) + # Retry with 10x more dampening (matches reference implementation) + print_rank_0( + f"Warning: Hessian not positive definite for {self.name}, " + "retrying with 10x dampening" + ) + h[diag_indices, diag_indices] += damp * 10 + try: + h = torch.cholesky_inverse(torch.linalg.cholesky(h)) + self.h_inv = torch.linalg.cholesky(h, upper=True) + except (RuntimeError, torch.linalg.LinAlgError): + print_rank_0( + f"Warning: Hessian still not positive definite for {self.name}, " + "using identity matrix" + ) + self.h_inv = torch.eye(h.shape[0], device=h.device, dtype=h.dtype) def _blockwise_update(self, block_size): 
"""Column-wise GPTQ update using full-matrix QDQ. @@ -1771,6 +1786,10 @@ def _blockwise_update(self, block_size): For each column, quantizes the full weight matrix via the quantizer and extracts the quantized column. This is the standard GPTQ approach. + For PSX LUTS vector quantizers, uses a two-phase approach: + 1. Compute scales once per outer block via dynamic quantization + 2. Use static (pre-scaled) quantization in the inner loop + Reads/writes ``self.weight`` and ``self.h_inv`` in-place. """ assert self.weight is not None and self.h_inv is not None, ( @@ -1785,6 +1804,21 @@ def _blockwise_update(self, block_size): f"GPTQ block_size ({block_size}) must be divisible by the quantizer" f" group_size ({group_size})" ) + + # Detect PSX LUTS vector quantizer for the fast static-scale path + is_psx_luts_vq = ( + getattr(quantizer, "backend", None) == "psx_luts" + and quantizer.backend_extra_args.get("lut_type", "vector_lut") == "vector_lut" + ) + + if is_psx_luts_vq: + self._blockwise_update_psx_luts(block_size, quantizer) + else: + self._blockwise_update_default(block_size, quantizer) + + def _blockwise_update_default(self, block_size, quantizer): + """Standard GPTQ blockwise update (full QDQ per column).""" + assert self.weight is not None and self.h_inv is not None num_cols = self.weight.shape[1] for block_start in range(0, num_cols, block_size): @@ -1808,6 +1842,118 @@ def _blockwise_update(self, block_size): errs, self.h_inv[block_start:block_end, block_end:], alpha=-1 ) + @staticmethod + def _dynamic_blockwise_vector_quantization( + x, vector_lut, block_size=16, scale_type="e4m3", return_scales=False + ): + """Dynamic VQ: computes scales from input, returns quantized output (and optionally scales).""" + from luts import clip_vector_scalesign_fast + + y = clip_vector_scalesign_fast( + x, + vector_lut, + block_size, + scale_type, + scale_algo="max", + sign_scale=True, + return_scales=return_scales, + ) + if return_scales: + return y[0], y[1] + return y + + 
@staticmethod + def _static_blockwise_vector_quantization(x, vector_lut, scales): + """Static VQ: uses pre-computed scales, returns quantized output.""" + from luts import clip_vector_prescaled + + return clip_vector_prescaled(x, vector_lut, scales) + + def _blockwise_update_psx_luts(self, block_size, quantizer): + """GPTQ blockwise update for PSX LUTS vector quantizers. + + Uses dynamic_blockwise_vector_quantization to pre-compute scales, + then static_blockwise_vector_quantization inside the GPTQ loop. + + Follows the 3-loop structure from the VQ GPTQ reference + (adaptive_rounding.py: gptq_quantize_scaled_vq). + """ + extra_args = quantizer.backend_extra_args + encode_format = quantizer.num_bits + encode_path = extra_args.get("encode_path", "") + if encode_path and not encode_path.endswith("/"): + encode_path += "/" + quant_block_size = extra_args.get("block_sizes", 16) + scale_type = extra_args.get("scale_type", "e4m3") + + # Load the vector LUT codebook + import luts + + if "sorted" not in encode_format: + values, _ = luts.encode(encode_format, path=encode_path, norm=False, cuda=True) + else: + sorted_codebook = torch.load( + encode_path + encode_format + ".pt", map_location="cpu" + ) + values = sorted_codebook["sorted_values"].cuda() + + values = values.to(torch.float) + vector_size = values.shape[1] + assert self.weight is not None and self.h_inv is not None + out_features, num_cols = self.weight.shape + + assert block_size % quant_block_size == 0, ( + f"GPTQ block_size ({block_size}) must be a multiple of " + f"quant_block_size ({quant_block_size})" + ) + + # Outside GPTQ loop: dynamic quantization to get scales + _, scales = self._dynamic_blockwise_vector_quantization( + self.weight, + values, + block_size=quant_block_size, + scale_type=scale_type, + return_scales=True, + ) + + # Reshape flat scales to 2D for per-vector-group extraction + n_scale_blocks_per_row = num_cols // quant_block_size + scales_2d = scales.reshape(out_features, 
n_scale_blocks_per_row) + + w = self.weight.clone() + q = torch.zeros_like(w) + h_inv = self.h_inv + + for i in range(0, num_cols, block_size): + j_end = min(i + block_size, num_cols) + e = torch.zeros(out_features, j_end - i, dtype=w.dtype, device=w.device) + + for j in range(i, j_end, vector_size): + d = min(vector_size, j_end - j) + sb = j // quant_block_size + s = scales_2d[:, sb].contiguous() + + # Inside GPTQ loop: static quantization with pre-computed scales + sub_vec = w[:, j : j + d].contiguous() + if d == vector_size: + q_sub = self._static_blockwise_vector_quantization(sub_vec, values, s) + else: + padded = torch.nn.functional.pad(sub_vec, (0, vector_size - d)) + q_sub = self._static_blockwise_vector_quantization(padded, values, s)[:, :d] + + q[:, j : j + d] = q_sub + + for k in range(d): + col = j + k + err = (w[:, col] - q[:, col]) / h_inv[col, col] + e[:, col - i] = err + w[:, col:j_end] -= err.unsqueeze(1) * h_inv[col, col:j_end].unsqueeze(0) + + if j_end < num_cols: + w[:, j_end:] -= e @ h_inv[i:j_end, j_end:] + + self.weight = q + def _print_mse_error(self, hessian): """Log Hessian-weighted relative MSE between ``self.weight`` and original weights.""" w_orig = self.module.weight.float() @@ -1847,6 +1993,12 @@ def _print_mse_error(self, hessian): print_rank_0("Updating weights using GPTQ algorithm...") for handle in gptq_handles.values(): handle.update_weights(block_size, percdamp) + wq = handle.module.weight_quantizer + backend = getattr(wq, "backend", None) + print_rank_0(f" [{handle.name}] weight_quantizer.backend={backend}") + if backend == "psx_luts": + wq.disable() + print_rank_0(f" Disabled weight_quantizer for {handle.name}") handle.free() del gptq_handles From b35bc850e0097fe994abed77be8a209c219c3a69 Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Tue, 31 Mar 2026 21:40:10 +0000 Subject: [PATCH 42/48] tested e2e on qwen3-8b Signed-off-by: Suguna Velury 
<178320438+sugunav14@users.noreply.github.com> --- modelopt/torch/quantization/model_calib.py | 1 + 1 file changed, 1 insertion(+) diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index e9ddaf6fd0..0752c7b41a 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -1878,6 +1878,7 @@ def _blockwise_update_psx_luts(self, block_size, quantizer): Follows the 3-loop structure from the VQ GPTQ reference (adaptive_rounding.py: gptq_quantize_scaled_vq). """ + print_rank_0(f" [{self.name}] Using PSX LUTS GPTQ path (v2)") extra_args = quantizer.backend_extra_args encode_format = quantizer.num_bits encode_path = extra_args.get("encode_path", "") From 0f621cdb8de1a532a59cda008c2d608de580f3e2 Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Mon, 6 Apr 2026 02:59:59 +0000 Subject: [PATCH 43/48] latest run with export Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- examples/llm_ptq/hf_ptq.py | 151 +++++++++++++++++--- examples/vllm_serve/fakequant_worker.py | 10 +- examples/vllm_serve/vllm_serve_fakequant.py | 1 + modelopt/torch/quantization/config.py | 9 ++ modelopt/torch/quantization/model_calib.py | 26 ++++ modelopt/torch/utils/dataset_utils.py | 19 ++- 6 files changed, 197 insertions(+), 19 deletions(-) diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index e017390481..ca6bfbdfb4 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -61,6 +61,7 @@ save_expert_token_count_table, ) from modelopt.torch.export.model_utils import get_language_model_from_vl, is_multimodal_model +from modelopt.torch.export.plugins.vllm_fakequant_hf import export_hf_vllm_fq_checkpoint from modelopt.torch.quantization.config import _default_disabled_quantizer_cfg, need_calibration from modelopt.torch.quantization.metrics_backup import ( ActivationMSELogger, @@ -247,6 +248,91 @@ 
def make_calib_dataloader( return calib_dataloader, first_text_speech_dataset +def make_mse_holdout_dataloader( + args: argparse.Namespace, + tokenizer: PreTrainedTokenizerBase, + device: torch.device, +) -> DataLoader: + """Create a hold-out dataloader for activation MSE from the same dataset as calibration. + + Samples are drawn from the same dataset/splits but starting *after* the calibration + region, so there is zero overlap. The skip count per split equals + ``calib_size // num_splits`` (matching how ``get_dataset_samples`` divides samples). + """ + from modelopt.torch.utils.dataset_utils import SUPPORTED_DATASET_CONFIG + + dataset_names = args.dataset + calib_sizes = args.calib_size + n_mse = getattr(args, "activation_mse_max_samples", 16) + + # Compute per-split skip: calib samples per dataset / number of splits for that dataset + skip_per_dataset = [] + for ds_name, cs in zip(dataset_names, calib_sizes): + if ds_name in SUPPORTED_DATASET_CONFIG: + n_splits = len(SUPPORTED_DATASET_CONFIG[ds_name]["config"].get("split", [None])) + else: + n_splits = 1 + skip_per_dataset.append(cs // max(n_splits, 1)) + + # Use the max skip across datasets (all datasets share the same skip_samples param) + skip = max(skip_per_dataset) + + # Number of hold-out samples per dataset, proportional to calib_size + total_calib = sum(calib_sizes) + holdout_sizes = [max(1, int(n_mse * cs / total_calib)) for cs in calib_sizes] + # Ensure we get exactly n_mse total + holdout_sizes[-1] = n_mse - sum(holdout_sizes[:-1]) + + print( + f"Creating MSE hold-out dataloader: skip_per_split={skip}, " + f"holdout_sizes={dict(zip(dataset_names, holdout_sizes))}" + ) + + holdout_dataloader = get_dataset_dataloader( + dataset_name=dataset_names, + tokenizer=tokenizer, + batch_size=args.batch_size, + num_samples=holdout_sizes, + device=device, + skip_samples=skip, + ) + return holdout_dataloader + + +def verify_no_overlap( + calib_dataloader: DataLoader, + holdout_dataloader: DataLoader, +) -> None: + 
"""Verify that calibration and hold-out dataloaders have no overlapping samples. + + Compares SHA-256 hashes of each row of input_ids across both dataloaders. + Raises AssertionError if any overlap is found. + """ + import hashlib + + def _collect_hashes(dataloader: DataLoader) -> set[str]: + hashes = set() + for batch in dataloader: + ids = batch["input_ids"] if isinstance(batch, dict) else batch + for row in ids: + h = hashlib.sha256(row.cpu().numpy().tobytes()).hexdigest() + hashes.add(h) + return hashes + + calib_hashes = _collect_hashes(calib_dataloader) + holdout_hashes = _collect_hashes(holdout_dataloader) + overlap = calib_hashes & holdout_hashes + + assert len(overlap) == 0, ( + f"Found {len(overlap)} overlapping samples between calibration and MSE hold-out data! " + f"This invalidates the MSE measurement. Check dataset/calib_size configuration." + ) + print( + f"[MSE hold-out] Overlap check passed: " + f"{len(calib_hashes)} calib vs {len(holdout_hashes)} hold-out, 0 overlap." + ) + + def auto_quantize( args: argparse.Namespace, language_model: torch.nn.Module, @@ -692,11 +778,17 @@ def export_quantized( if mtp_layer_prefixes: full_model._mtp_layer_prefixes = mtp_layer_prefixes - export_hf_checkpoint( - full_model, - export_dir=export_path, - extra_state_dict=mtp_state_dict, - ) + if args.vllm_fakequant_export: + export_hf_vllm_fq_checkpoint( + full_model, + export_dir=export_path, + ) + else: + export_hf_checkpoint( + full_model, + export_dir=export_path, + extra_state_dict=mtp_state_dict, + ) # Restore default padding and export the tokenizer as well. 
if tokenizer is not None: @@ -998,7 +1090,7 @@ def quantize_main( mse_save_dir = getattr(args, "activation_mse_save_dir", None) mse_input_path = getattr(args, "activation_mse_input_path", None) - # Resolve MSE input data: frozen file (raw text or tokenized) or live dataloader + # Resolve MSE input data: frozen file (raw text or tokenized), or hold-out set mse_data = None if mse_input_path is not None: if mse_input_path.endswith(".json"): @@ -1014,9 +1106,13 @@ def quantize_main( assert tokenizer is not None, ( "--activation_mse_input_path with .json requires a tokenizer to decode" ) - print(f"Creating MSE input data .json file: {mse_input_path}") + print( + f"Creating MSE input data .json file from hold-out set: {mse_input_path}" + ) + holdout_dl = make_mse_holdout_dataloader(args, tokenizer, device) + verify_no_overlap(calib_dataloader, holdout_dl) texts = ActivationMSELogger.materialize_raw_text( - calib_dataloader, + holdout_dl, mse_input_path, tokenizer=tokenizer, max_samples=n_mse, @@ -1030,10 +1126,15 @@ def quantize_main( if os.path.isfile(mse_input_path): print(f"Loading MSE input data from existing .pt file: {mse_input_path}") mse_data = ActivationMSELogger.load_data(mse_input_path) + verify_no_overlap(calib_dataloader, mse_data) else: - print(f"Creating MSE input data .pt file: {mse_input_path}") + print( + f"Creating MSE input data .pt file from hold-out set: {mse_input_path}" + ) + holdout_dl = make_mse_holdout_dataloader(args, tokenizer, device) + verify_no_overlap(calib_dataloader, holdout_dl) mse_data = ActivationMSELogger.materialize_data( - calib_dataloader, + holdout_dl, mse_input_path, max_samples=n_mse, ) @@ -1043,7 +1144,11 @@ def quantize_main( ) if mse_data is None: - mse_data = calib_dataloader + # Default: create a hold-out set from the same dataset, skipping + # the calibration region to avoid overlap. 
+ print("Creating MSE hold-out dataloader (non-overlapping with calibration)...") + mse_data = make_mse_holdout_dataloader(args, tokenizer, device) + verify_no_overlap(calib_dataloader, mse_data) mse_logger = ActivationMSELogger(max_samples=n_mse, save_dir=mse_save_dir) print(f"Collecting original (unquantized) activations for MSE over {n_mse} samples...") @@ -1104,12 +1209,16 @@ def quantize_main( # Plugin-registered configs (e.g. PSX LUTS from modelopt-internal) are not exportable # via the standard TRT-LLM / HF export paths. Fall back to save_pretrained(). if args.qformat not in QUANT_CFG_CHOICES and hasattr(mtq, args.qformat): - print( - f"qformat '{args.qformat}' is a plugin-registered config and is not exportable " - f"via the standard export pipeline. Saving with save_pretrained() instead." - ) export_path = args.export_path - full_model.save_pretrained(export_path) + if args.vllm_fakequant_export: + print(f"Exporting vLLM fakequant checkpoint (bf16 weights + amax) to: {export_path}") + export_hf_vllm_fq_checkpoint(full_model, export_dir=export_path) + else: + print( + f"qformat '{args.qformat}' is a plugin-registered config and is not exportable " + f"via the standard export pipeline. Saving with save_pretrained() instead." + ) + full_model.save_pretrained(export_path) if tokenizer is not None: tokenizer.save_pretrained(export_path) print(f"Quantized model saved to: {export_path}") @@ -1176,6 +1285,16 @@ def parse_args() -> argparse.Namespace: default=512, ) parser.add_argument("--export_path", default="exported_model") + parser.add_argument( + "--vllm_fakequant_export", + action="store_true", + default=False, + help=( + "Export bf16 weights and amax values separately for vLLM fakequant serving. " + "Produces a standard HF checkpoint with GPTQ-adjusted weights plus a " + "quant_amax.pth file that can be loaded via AMAX_FILE_PATH in vllm_serve_fakequant.py." 
+ ), + ) parser.add_argument( "--dataset", help=( diff --git a/examples/vllm_serve/fakequant_worker.py b/examples/vllm_serve/fakequant_worker.py index 772c6fe669..b717b60da5 100644 --- a/examples/vllm_serve/fakequant_worker.py +++ b/examples/vllm_serve/fakequant_worker.py @@ -152,6 +152,7 @@ def disable_compilation(model): "quant_cfg": os.environ.get("QUANT_CFG", None), "kv_quant_cfg": os.environ.get("KV_QUANT_CFG", None), "amax_file_path": os.environ.get("AMAX_FILE_PATH", None), + "skip_fold_weight": os.environ.get("SKIP_FOLD_WEIGHT", "0") == "1", } @@ -329,7 +330,14 @@ def calibrate_loop(model: Any = None) -> None: if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0: mtq.print_quant_summary(model) - mtq.fold_weight(model) + if quant_config["skip_fold_weight"]: + print("Skipping fold_weight (weights already quantized, e.g. from GPTQ export)") + for name, module in model.named_modules(): + if name.endswith("weight_quantizer"): + module.disable() + else: + mtq.fold_weight(model) + for name, module in model.named_modules(): if name.endswith("weight_quantizer"): assert not module.is_enabled, f"quantizer {name} is still enabled" diff --git a/examples/vllm_serve/vllm_serve_fakequant.py b/examples/vllm_serve/vllm_serve_fakequant.py index 25483f2be1..fb14564568 100644 --- a/examples/vllm_serve/vllm_serve_fakequant.py +++ b/examples/vllm_serve/vllm_serve_fakequant.py @@ -76,6 +76,7 @@ "QUANT_CFG", "AMAX_FILE_PATH", "KV_QUANT_CFG", + "SKIP_FOLD_WEIGHT", } RayDistributedExecutor.ADDITIONAL_ENV_VARS.update(additional_env_vars) diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index be11108b28..1e74eb79b2 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -1334,6 +1334,15 @@ class GPTQConfig(QuantizeAlgorithmConfig): description="""The block size for GPTQ weight update, which must be a multiple of the group_size used in the quantization.""", ) + skip_layers: 
list[int] = ModeloptField( + default=[], + title="Decoder layer indices to skip GPTQ weight update.", + description=( + "List of decoder layer indices (0-based) for which GPTQ weight update is skipped. " + "These layers still receive max calibration (QDQ amax) but no Hessian-based weight " + "adjustment. Only effective with use_sequential=True." + ), + ) QuantizeQuantCfgType = dict[ diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 0752c7b41a..8fe4b72a3b 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -1632,6 +1632,7 @@ def gptq( forward_loop: ForwardLoop, percdamp: float = 0.01, block_size: int = 128, + skip_layers: list[int] | None = None, ): """GPTQ quantization. @@ -1886,6 +1887,7 @@ def _blockwise_update_psx_luts(self, block_size, quantizer): encode_path += "/" quant_block_size = extra_args.get("block_sizes", 16) scale_type = extra_args.get("scale_type", "e4m3") + print(f"[GPTQ psx_luts] quant_block_size={quant_block_size}, scale_type={scale_type}") # Load the vector LUT codebook import luts @@ -1900,6 +1902,7 @@ def _blockwise_update_psx_luts(self, block_size, quantizer): values = values.to(torch.float) vector_size = values.shape[1] + print(f"[GPTQ psx_luts] vector_size={vector_size}, codebook_shape={values.shape}") assert self.weight is not None and self.h_inv is not None out_features, num_cols = self.weight.shape @@ -1970,6 +1973,29 @@ def _print_mse_error(self, hessian): max_calibrate(model, forward_loop=forward_loop) _promote_nvfp4_static_quantizers(model) + # Skip GPTQ weight update for specified layers — fold weights via QDQ instead. 
+ layer_idx = getattr(model, "_seq_calib_layer_idx", None) + if skip_layers and layer_idx is not None and layer_idx in skip_layers: + print_rank_0( + f"[Layer {layer_idx}] In skip_layers {skip_layers} → using RTN path (no GPTQ weight update)" + ) + rtn_count = 0 + for name, module in model.named_modules(): + if is_quantized_linear(module) and module.weight_quantizer.is_enabled: + wq = module.weight_quantizer + with torch.no_grad(): + module.weight.data = wq(module.weight).to(module.weight.dtype) + backend = getattr(wq, "backend", None) + if backend == "psx_luts": + wq.disable() + rtn_count += 1 + print_rank_0(f" [RTN] {name} — QDQ-folded (backend={backend})") + print_rank_0(f"[Layer {layer_idx}] RTN path complete: {rtn_count} layers folded via QDQ") + return + + if layer_idx is not None: + print_rank_0(f"[Layer {layer_idx}] Not in skip_layers → using GPTQ path") + quantized_layers = [ (n, m) for n, m in model.named_modules() diff --git a/modelopt/torch/utils/dataset_utils.py b/modelopt/torch/utils/dataset_utils.py index 29e0d8a882..cfd3e2c6bc 100644 --- a/modelopt/torch/utils/dataset_utils.py +++ b/modelopt/torch/utils/dataset_utils.py @@ -218,6 +218,7 @@ def get_dataset_samples( apply_chat_template: bool = False, tokenizer: "PreTrainedTokenizerBase | None" = None, split: str | list[str] | None = None, + skip_samples: int = 0, ) -> list[str]: """Load a portion of a dataset with the dataset name and a given size. @@ -240,6 +241,9 @@ def get_dataset_samples( split: Override the split(s) to load. Accepts a single split name or a list. If ``None``, uses the splits defined in ``SUPPORTED_DATASET_CONFIG`` for registered datasets, or ``["train"]`` for unregistered datasets. + skip_samples: Number of samples to skip from the beginning of each split + before collecting. The skip count is applied **per split**, so + ``skip_samples=128`` with 4 splits skips 128 from each split. Returns: Samples: The list of samples. 
@@ -306,12 +310,16 @@ def _preprocess(sample: dict) -> str: samples: list[str] = [] for dataset, n in zip(dataset_splits, num_per_split): + collected = 0 for i, sample in enumerate(dataset): - if i >= n: + if i < skip_samples: + continue + if collected >= n: break text = _preprocess(sample) if text: samples.append(text) + collected += 1 return samples @@ -340,6 +348,7 @@ def get_dataset_dataloader( device: torch.device | None = None, include_labels: bool = False, apply_chat_template: bool = False, + skip_samples: int = 0, ) -> DataLoader: """Get a dataloader with the dataset name and tokenizer of the target model. @@ -354,6 +363,8 @@ def get_dataset_dataloader( include_labels: Whether to include labels in the dataloader. apply_chat_template: Whether to apply the chat template to the samples (if supported by the dataset). + skip_samples: Number of samples to skip per split before collecting + (see :func:`get_dataset_samples`). Returns: An instance of dataloader. @@ -380,7 +391,11 @@ def get_dataset_dataloader( all_samples = [] for ds_name, num_sample in zip(dataset_name, num_samples): samples = get_dataset_samples( - ds_name, num_sample, apply_chat_template=apply_chat_template, tokenizer=tokenizer + ds_name, + num_sample, + apply_chat_template=apply_chat_template, + tokenizer=tokenizer, + skip_samples=skip_samples, ) all_samples.extend(samples) From af59a5531b6fcf4d8454ad27fda18a7c0b8aec9e Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Mon, 6 Apr 2026 03:37:30 +0000 Subject: [PATCH 44/48] clean up Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- examples/llm_ptq/hf_ptq.py | 177 ++++----------------- modelopt/torch/quantization/config.py | 9 -- modelopt/torch/quantization/model_calib.py | 51 ++---- 3 files changed, 41 insertions(+), 196 deletions(-) diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index ca6bfbdfb4..90a5dd68ea 100755 --- a/examples/llm_ptq/hf_ptq.py 
+++ b/examples/llm_ptq/hf_ptq.py @@ -15,7 +15,6 @@ import argparse import copy -import os import random import time import warnings @@ -248,47 +247,28 @@ def make_calib_dataloader( return calib_dataloader, first_text_speech_dataset -def make_mse_holdout_dataloader( - args: argparse.Namespace, - tokenizer: PreTrainedTokenizerBase, - device: torch.device, -) -> DataLoader: - """Create a hold-out dataloader for activation MSE from the same dataset as calibration. - - Samples are drawn from the same dataset/splits but starting *after* the calibration - region, so there is zero overlap. The skip count per split equals - ``calib_size // num_splits`` (matching how ``get_dataset_samples`` divides samples). - """ +def _make_mse_holdout_dataloader(args, tokenizer, device): + """Create a hold-out dataloader for activation MSE, skipping calibration samples.""" from modelopt.torch.utils.dataset_utils import SUPPORTED_DATASET_CONFIG - dataset_names = args.dataset - calib_sizes = args.calib_size - n_mse = getattr(args, "activation_mse_max_samples", 16) + dataset_names, calib_sizes = args.dataset, args.calib_size + n_mse = args.activation_mse_max_samples - # Compute per-split skip: calib samples per dataset / number of splits for that dataset + # Per-split skip = calib samples / number of splits for that dataset skip_per_dataset = [] for ds_name, cs in zip(dataset_names, calib_sizes): - if ds_name in SUPPORTED_DATASET_CONFIG: - n_splits = len(SUPPORTED_DATASET_CONFIG[ds_name]["config"].get("split", [None])) - else: - n_splits = 1 + n_splits = len( + SUPPORTED_DATASET_CONFIG.get(ds_name, {}).get("config", {}).get("split", [None]) + ) skip_per_dataset.append(cs // max(n_splits, 1)) - - # Use the max skip across datasets (all datasets share the same skip_samples param) skip = max(skip_per_dataset) - # Number of hold-out samples per dataset, proportional to calib_size + # Distribute hold-out samples proportionally across datasets total_calib = sum(calib_sizes) holdout_sizes = [max(1, 
int(n_mse * cs / total_calib)) for cs in calib_sizes] - # Ensure we get exactly n_mse total holdout_sizes[-1] = n_mse - sum(holdout_sizes[:-1]) - print( - f"Creating MSE hold-out dataloader: skip_per_split={skip}, " - f"holdout_sizes={dict(zip(dataset_names, holdout_sizes))}" - ) - - holdout_dataloader = get_dataset_dataloader( + return get_dataset_dataloader( dataset_name=dataset_names, tokenizer=tokenizer, batch_size=args.batch_size, @@ -296,41 +276,6 @@ def make_mse_holdout_dataloader( device=device, skip_samples=skip, ) - return holdout_dataloader - - -def verify_no_overlap( - calib_dataloader: DataLoader, - holdout_dataloader: DataLoader, -) -> None: - """Verify that calibration and hold-out dataloaders have no overlapping samples. - - Compares SHA-256 hashes of each row of input_ids across both dataloaders. - Raises AssertionError if any overlap is found. - """ - import hashlib - - def _collect_hashes(dataloader: DataLoader) -> set[str]: - hashes = set() - for batch in dataloader: - ids = batch["input_ids"] if isinstance(batch, dict) else batch - for row in ids: - h = hashlib.sha256(row.cpu().numpy().tobytes()).hexdigest() - hashes.add(h) - return hashes - - calib_hashes = _collect_hashes(calib_dataloader) - holdout_hashes = _collect_hashes(holdout_dataloader) - overlap = calib_hashes & holdout_hashes - - assert len(overlap) == 0, ( - f"Found {len(overlap)} overlapping samples between calibration and MSE hold-out data! " - f"This invalidates the MSE measurement. Check dataset/calib_size configuration." - ) - print( - f"[MSE hold-out] Overlap check passed: " - f"{len(calib_hashes)} calib vs {len(holdout_hashes)} hold-out, 0 overlap." 
- ) def auto_quantize( @@ -1085,73 +1030,19 @@ def quantize_main( _set_kv_cache_constant_amax(quant_cfg["quant_cfg"]) # Collect original (unquantized) activations before quantization modifies the model - if getattr(args, "measure_activation_mse", False): - n_mse = getattr(args, "activation_mse_max_samples", 16) - mse_save_dir = getattr(args, "activation_mse_save_dir", None) - mse_input_path = getattr(args, "activation_mse_input_path", None) - - # Resolve MSE input data: frozen file (raw text or tokenized), or hold-out set - mse_data = None - if mse_input_path is not None: - if mse_input_path.endswith(".json"): - if os.path.isfile(mse_input_path): - print(f"Loading MSE input data from existing .json file: {mse_input_path}") - texts = ActivationMSELogger.load_raw_text(mse_input_path) - mse_data = ActivationMSELogger.tokenize_raw_text( - texts, - tokenizer, - max_length=args.calib_seq, - ) - else: - assert tokenizer is not None, ( - "--activation_mse_input_path with .json requires a tokenizer to decode" - ) - print( - f"Creating MSE input data .json file from hold-out set: {mse_input_path}" - ) - holdout_dl = make_mse_holdout_dataloader(args, tokenizer, device) - verify_no_overlap(calib_dataloader, holdout_dl) - texts = ActivationMSELogger.materialize_raw_text( - holdout_dl, - mse_input_path, - tokenizer=tokenizer, - max_samples=n_mse, - ) - mse_data = ActivationMSELogger.tokenize_raw_text( - texts, - tokenizer, - max_length=args.calib_seq, - ) - elif mse_input_path.endswith(".pt"): - if os.path.isfile(mse_input_path): - print(f"Loading MSE input data from existing .pt file: {mse_input_path}") - mse_data = ActivationMSELogger.load_data(mse_input_path) - verify_no_overlap(calib_dataloader, mse_data) - else: - print( - f"Creating MSE input data .pt file from hold-out set: {mse_input_path}" - ) - holdout_dl = make_mse_holdout_dataloader(args, tokenizer, device) - verify_no_overlap(calib_dataloader, holdout_dl) - mse_data = ActivationMSELogger.materialize_data( - 
holdout_dl, - mse_input_path, - max_samples=n_mse, - ) - else: - raise ValueError( - f"--activation_mse_input_path must end with .json or .pt, got: {mse_input_path}" - ) - - if mse_data is None: - # Default: create a hold-out set from the same dataset, skipping - # the calibration region to avoid overlap. - print("Creating MSE hold-out dataloader (non-overlapping with calibration)...") - mse_data = make_mse_holdout_dataloader(args, tokenizer, device) - verify_no_overlap(calib_dataloader, mse_data) - - mse_logger = ActivationMSELogger(max_samples=n_mse, save_dir=mse_save_dir) - print(f"Collecting original (unquantized) activations for MSE over {n_mse} samples...") + if args.measure_activation_mse: + mse_logger = ActivationMSELogger( + max_samples=args.activation_mse_max_samples, + save_dir=args.activation_mse_save_dir, + ) + mse_data = ActivationMSELogger.resolve_data( + input_path=args.activation_mse_input_path, + calib_dataloader=calib_dataloader, + tokenizer=tokenizer, + max_samples=args.activation_mse_max_samples, + max_length=args.calib_seq, + make_holdout_fn=lambda: _make_mse_holdout_dataloader(args, tokenizer, device), + ) mse_logger.collect(language_model, mse_data, phase="original") if args.qformat in QUANT_CFG_CHOICES or hasattr(mtq, args.qformat): @@ -1182,29 +1073,17 @@ def quantize_main( ) if mse_logger is not None: - import gc - - print("Collecting quantized activations for MSE...") - mse_logger.collect(language_model, mse_data, phase="quantized") - - mse_logger.compute_mse() - print(mse_logger.summary()) - - if getattr(args, "activation_mse_save_dir", None): - mse_logger.save() - + mse_logger.finish(language_model, mse_data) del mse_logger, mse_data - gc.collect() torch.cuda.empty_cache() - if getattr(args, "eval_perplexity", False) and tokenizer is not None: - if getattr(args, "fold_weights", False): + if args.eval_perplexity and tokenizer is not None: + if args.fold_weights: print("Folding weights before perplexity evaluation...") 
mtq.fold_weight(language_model) - seq_len = getattr(args, "eval_perplexity_seq_len", 2048) - eval_data = get_wikitext2(tokenizer, seq_len) + eval_data = get_wikitext2(tokenizer, args.eval_perplexity_seq_len) ppl = compute_perplexity(full_model, eval_data) - print(f"Wikitext-2 perplexity: {round(ppl, 2):.2f}") + print(f"Wikitext-2 perplexity: {ppl:.2f}") # Plugin-registered configs (e.g. PSX LUTS from modelopt-internal) are not exportable # via the standard TRT-LLM / HF export paths. Fall back to save_pretrained(). diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index 1e74eb79b2..be11108b28 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -1334,15 +1334,6 @@ class GPTQConfig(QuantizeAlgorithmConfig): description="""The block size for GPTQ weight update, which must be a multiple of the group_size used in the quantization.""", ) - skip_layers: list[int] = ModeloptField( - default=[], - title="Decoder layer indices to skip GPTQ weight update.", - description=( - "List of decoder layer indices (0-based) for which GPTQ weight update is skipped. " - "These layers still receive max calibration (QDQ amax) but no Hessian-based weight " - "adjustment. Only effective with use_sequential=True." - ), - ) QuantizeQuantCfgType = dict[ diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 8fe4b72a3b..cb95a16ea2 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -1609,7 +1609,7 @@ def _promote_nvfp4_static_quantizers(model: nn.Module) -> int: Returns the number of quantizers converted. 
""" converted = 0 - for _name, module in list(model.named_modules()): + for module in model.modules(): if isinstance(module, TensorQuantizer) and not module._disabled: if module._calibrator is not None and not module._dynamic and hasattr(module, "_amax"): is_nvfp4_static = ( @@ -1632,7 +1632,6 @@ def gptq( forward_loop: ForwardLoop, percdamp: float = 0.01, block_size: int = 128, - skip_layers: list[int] | None = None, ): """GPTQ quantization. @@ -1696,16 +1695,11 @@ def setup(self): def hessian_forward(self, input, *args, **kwargs): inp = input.to_local() if hasattr(input, "to_local") else input if self.input_quantizer is not None and self.input_quantizer.is_enabled: - hessian_input = self.input_quantizer(inp) - else: - hessian_input = inp + inp = self.input_quantizer(inp) gptq_helper.hessian, gptq_helper.n_samples = update_hessian( - hessian_input, gptq_helper.hessian, gptq_helper.n_samples + inp, gptq_helper.hessian, gptq_helper.n_samples ) - - out = self._forward_no_gptq_hessian(input, *args, **kwargs) - - return out + return self._forward_no_gptq_hessian(input, *args, **kwargs) bind_forward_method(self.module, hessian_forward, self.CACHE_NAME) @@ -1879,7 +1873,7 @@ def _blockwise_update_psx_luts(self, block_size, quantizer): Follows the 3-loop structure from the VQ GPTQ reference (adaptive_rounding.py: gptq_quantize_scaled_vq). 
""" - print_rank_0(f" [{self.name}] Using PSX LUTS GPTQ path (v2)") + print_rank_0(f" [{self.name}] Using PSX LUTS GPTQ path") extra_args = quantizer.backend_extra_args encode_format = quantizer.num_bits encode_path = extra_args.get("encode_path", "") @@ -1887,7 +1881,9 @@ def _blockwise_update_psx_luts(self, block_size, quantizer): encode_path += "/" quant_block_size = extra_args.get("block_sizes", 16) scale_type = extra_args.get("scale_type", "e4m3") - print(f"[GPTQ psx_luts] quant_block_size={quant_block_size}, scale_type={scale_type}") + print_rank_0( + f"[GPTQ psx_luts] quant_block_size={quant_block_size}, scale_type={scale_type}" + ) # Load the vector LUT codebook import luts @@ -1902,7 +1898,9 @@ def _blockwise_update_psx_luts(self, block_size, quantizer): values = values.to(torch.float) vector_size = values.shape[1] - print(f"[GPTQ psx_luts] vector_size={vector_size}, codebook_shape={values.shape}") + print_rank_0( + f"[GPTQ psx_luts] vector_size={vector_size}, codebook_shape={values.shape}" + ) assert self.weight is not None and self.h_inv is not None out_features, num_cols = self.weight.shape @@ -1942,7 +1940,7 @@ def _blockwise_update_psx_luts(self, block_size, quantizer): if d == vector_size: q_sub = self._static_blockwise_vector_quantization(sub_vec, values, s) else: - padded = torch.nn.functional.pad(sub_vec, (0, vector_size - d)) + padded = F.pad(sub_vec, (0, vector_size - d)) q_sub = self._static_blockwise_vector_quantization(padded, values, s)[:, :d] q[:, j : j + d] = q_sub @@ -1962,7 +1960,7 @@ def _print_mse_error(self, hessian): """Log Hessian-weighted relative MSE between ``self.weight`` and original weights.""" w_orig = self.module.weight.float() delta = self.weight - w_orig - mse = (delta).mm(hessian).mul(delta).mean() / ( + mse = delta.mm(hessian).mul(delta).mean() / ( w_orig.mm(hessian).mul(w_orig).mean() + 1e-6 ) suffix = f", n_hessian_samples: {self.n_samples}" if self.n_samples else "" @@ -1973,29 +1971,6 @@ def _print_mse_error(self, 
hessian): max_calibrate(model, forward_loop=forward_loop) _promote_nvfp4_static_quantizers(model) - # Skip GPTQ weight update for specified layers — fold weights via QDQ instead. - layer_idx = getattr(model, "_seq_calib_layer_idx", None) - if skip_layers and layer_idx is not None and layer_idx in skip_layers: - print_rank_0( - f"[Layer {layer_idx}] In skip_layers {skip_layers} → using RTN path (no GPTQ weight update)" - ) - rtn_count = 0 - for name, module in model.named_modules(): - if is_quantized_linear(module) and module.weight_quantizer.is_enabled: - wq = module.weight_quantizer - with torch.no_grad(): - module.weight.data = wq(module.weight).to(module.weight.dtype) - backend = getattr(wq, "backend", None) - if backend == "psx_luts": - wq.disable() - rtn_count += 1 - print_rank_0(f" [RTN] {name} — QDQ-folded (backend={backend})") - print_rank_0(f"[Layer {layer_idx}] RTN path complete: {rtn_count} layers folded via QDQ") - return - - if layer_idx is not None: - print_rank_0(f"[Layer {layer_idx}] Not in skip_layers → using GPTQ path") - quantized_layers = [ (n, m) for n, m in model.named_modules() From d52d614ae5cf4e85d37fa2dd894228c29437c888 Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Mon, 6 Apr 2026 17:56:51 +0000 Subject: [PATCH 45/48] update Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- .../utils/activation_collector.py | 129 ++++++++++--- .../torch/quantization/utils/checkpoint.py | 173 ++++++++++++++++++ 2 files changed, 275 insertions(+), 27 deletions(-) create mode 100644 modelopt/torch/quantization/utils/checkpoint.py diff --git a/modelopt/torch/quantization/utils/activation_collector.py b/modelopt/torch/quantization/utils/activation_collector.py index 5f187fdcb2..9b7c4a3406 100644 --- a/modelopt/torch/quantization/utils/activation_collector.py +++ b/modelopt/torch/quantization/utils/activation_collector.py @@ -22,7 +22,7 @@ from collections import deque from 
dataclasses import dataclass, field -from typing import Any +from typing import Any, Literal import torch import torch.nn as nn @@ -44,11 +44,12 @@ class _LayerCalibState: patched forward to decide skip / run / capture / original behaviour. """ - mode: str = "original" + mode: Literal["original", "skip", "run", "capture"] = "original" name: str = "" cached_inputs: deque = field(default_factory=deque) collected_inputs: list = field(default_factory=list) output_meta: tuple | None = None + capture_output_meta: bool = False class LayerActivationCollector: @@ -150,7 +151,10 @@ def _zeros_from_meta(meta): # downstream run-mode layer, which replays from its own cached inputs instead. return meta[1] - def _patch_all_layers(self, decoder_layers: nn.ModuleList | None = None): + def _patch_all_layers( + self, + decoder_layers: nn.ModuleList | None = None, + ): """Bind the unified forward to every decoder layer and the model. Called once. Args: @@ -188,7 +192,10 @@ def _patched_forward(self, *args, **kwargs): info.collected_inputs.append((args, kwargs)) raise _EarlyStopForwardError() - return self._original_forward(*args, **kwargs) + output = self._original_forward(*args, **kwargs) + if info.capture_output_meta: + info.output_meta = LayerActivationCollector._extract_output_meta(output) + return output if decoder_layers is not None: self._decoder_layers = decoder_layers @@ -200,7 +207,7 @@ def _patched_forward(self, *args, **kwargs): module_to_name = {m: name for name, m in self.model.named_modules()} try: - for layer in self._decoder_layers: + for i, layer in enumerate(self._decoder_layers): layer._seq_calib = _LayerCalibState( name=module_to_name.get(layer, type(layer).__name__), ) @@ -238,6 +245,26 @@ def _unpatch_all_layers(self): self._cleanup_layers() self._patched = False + def _set_layer_mode( + self, layer_idx: int, mode: Literal["original", "skip", "run", "capture"] + ) -> None: + """Set the mode for a single decoder layer with appropriate side effects.""" + assert 
self._decoder_layers is not None + state = self._decoder_layers[layer_idx]._seq_calib + state.mode = mode + + if mode == "skip": + state.cached_inputs.clear() + elif mode == "run": + if not state.collected_inputs: + raise RuntimeError( + f"Layer {layer_idx} ({state.name!r}) has no collected inputs to replay." + ) + state.cached_inputs = deque(state.collected_inputs) + state.collected_inputs = [] + elif mode == "capture": + state.collected_inputs = [] + def _set_layer_states(self, layer_idx: int): """Transition layer modes for the next calibration step. @@ -247,30 +274,11 @@ def _set_layer_states(self, layer_idx: int): * Layer ``i - 1`` → **run** (replay captured inputs with calibrated weights). * Layer ``i`` → **capture** (record inputs, then early-stop). """ - assert self._decoder_layers is not None - if layer_idx > 1: - done = self._decoder_layers[layer_idx - 2]._seq_calib - # output_meta is intentionally kept: skip mode needs it to produce - # correctly shaped zero-filled outputs for the parent forward. - done.mode = "skip" - done.cached_inputs.clear() - + self._set_layer_mode(layer_idx - 2, "skip") if layer_idx > 0: - prev = self._decoder_layers[layer_idx - 1]._seq_calib - if not prev.collected_inputs: - raise RuntimeError( - f"Layer {layer_idx - 1} ({prev.name!r}) has no collected inputs to replay. " - "Layers must be calibrated sequentially — ensure get_input_activations() " - "was called for every preceding layer in order." 
- ) - prev.mode = "run" - prev.cached_inputs = deque(prev.collected_inputs) - prev.collected_inputs = [] - - cur = self._decoder_layers[layer_idx]._seq_calib - cur.mode = "capture" - cur.collected_inputs = [] + self._set_layer_mode(layer_idx - 1, "run") + self._set_layer_mode(layer_idx, "capture") def _log_layer_summary(self, layer_idx: int): """Log a one-line summary of layer modes for the current calibration step.""" @@ -284,6 +292,30 @@ def _log_layer_summary(self, layer_idx: int): parts = [f"{mode}: {groups[mode]}" for mode in ("skip", "run", "capture") if mode in groups] print_rank_0(f"Calibrating layer {layer_idx + 1}/{n} | {' | '.join(parts)}") + def _run_warmup_capture(self, capture_layer_idx: int, forward_loop: ForwardLoop) -> None: + """Run a forward pass with *capture_layer_idx* in capture mode. + + Raises RuntimeError if no inputs are collected. + """ + assert self._decoder_layers is not None + state = self._decoder_layers[capture_layer_idx]._seq_calib + state.mode = "capture" + state.collected_inputs = [] + + try: + forward_loop(self.model) + except Exception: + state.mode = "original" + state.collected_inputs = [] + raise + + if not state.collected_inputs: + state.mode = "original" + raise RuntimeError( + f"Warm-up forward collected no inputs for layer {capture_layer_idx}. " + "Cannot resume sequential calibration." + ) + # ------------------------------------------------------------------ # Public API # ------------------------------------------------------------------ @@ -333,3 +365,46 @@ def get_input_activations(self, layer: torch.nn.Module, forward_loop: ForwardLoo # in subsequent iterations via _set_layer_states. info.mode = "original" return inputs + + @torch.no_grad() + def prepare_for_resume( + self, + resume_layer_idx: int, + forward_loop: ForwardLoop, + ): + """Set up layer states for resuming sequential calibration from a checkpoint. 
+ + Runs a single warm-up forward pass so that the next call to + :meth:`get_input_activations` for ``resume_layer_idx`` produces the + correct inputs. Layers ``0 .. K-2`` run in *original* mode (with + ``capture_output_meta`` enabled so skip-mode metadata is populated), + layer ``K-1`` in *capture* mode. After the pass, ``0 .. K-2`` switch + to *skip* and ``K-1`` retains its ``collected_inputs`` for the + subsequent *run* transition. + """ + if not self._patched: + raise RuntimeError( + "prepare_for_resume() requires _patch_all_layers() to be called first." + ) + if resume_layer_idx == 0: + return + + k = resume_layer_idx + preceding = range(k - 1) + + assert self._decoder_layers is not None + for i in preceding: + self._set_layer_mode(i, "original") + self._decoder_layers[i]._seq_calib.capture_output_meta = True + + print_rank_0( + f"Running warm-up forward pass for resume " + f"(layers 0..{k - 2} original, layer {k - 1} capture)" + ) + self._run_warmup_capture(k - 1, forward_loop) + + for i in preceding: + self._decoder_layers[i]._seq_calib.capture_output_meta = False + self._set_layer_mode(i, "skip") + + print_rank_0(f"Warm-up complete. Ready to resume from layer {k}.") diff --git a/modelopt/torch/quantization/utils/checkpoint.py b/modelopt/torch/quantization/utils/checkpoint.py new file mode 100644 index 0000000000..aa4fd4e538 --- /dev/null +++ b/modelopt/torch/quantization/utils/checkpoint.py @@ -0,0 +1,173 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Checkpoint save/resume utilities for sequential calibration. + +Provides: + +* A pluggable **save registry** — plugins (e.g. huggingface.py) register a + ``(predicate, save_fn)`` pair at import time so that + :func:`get_checkpoint_saver` can find the right saver for any model. + +* **Resume detection** — :func:`detect_sequential_resume_layer` reads progress + metadata previously attached to the model and returns the layer index to + resume from. + +* **Checkpoint saving** — :func:`save_sequential_checkpoint` attaches progress + to the model and delegates to the registered saver. +""" + +from __future__ import annotations + +import os +from typing import TYPE_CHECKING + +from modelopt.torch.utils import print_rank_0 + +if TYPE_CHECKING: + from collections.abc import Callable + + import torch.nn as nn + +#: Model attribute name used to store sequential calibration progress. 
+SEQ_CALIB_PROGRESS_ATTR = "_seq_calib_progress" + +# --------------------------------------------------------------------------- +# Save registry +# --------------------------------------------------------------------------- +_CHECKPOINT_SAVE_SUPPORT: list[ + tuple[Callable[[nn.Module], bool], Callable[[nn.Module, str], None]] +] = [] + + +def register_seq_calib_checkpoint_saver( + is_supported: Callable[[nn.Module], bool], + save_fn: Callable[[nn.Module, str], None], +) -> None: + """Register a ``(predicate, saver)`` pair for sequential calibration checkpointing.""" + entry = (is_supported, save_fn) + if entry not in _CHECKPOINT_SAVE_SUPPORT: + _CHECKPOINT_SAVE_SUPPORT.append(entry) + + +def get_checkpoint_saver( + model: nn.Module, +) -> Callable[[nn.Module, str], None] | None: + """Return the registered save function for *model*, or *None*.""" + for is_supported, save_fn in _CHECKPOINT_SAVE_SUPPORT: + if is_supported(model): + return save_fn + return None + + +def detect_sequential_resume_layer(model: nn.Module, num_layers: int) -> int: + """Read checkpoint progress from the model and return the layer index to resume from. + + Returns ``0`` for a fresh run with no checkpoint present. + The attribute is **not** deleted here — cleanup is owned by + :func:`sequential_calibrate`'s ``finally`` block. + """ + progress = getattr(model, SEQ_CALIB_PROGRESS_ATTR, None) + if progress is None: + return 0 + + if not isinstance(progress, dict): + raise ValueError( + f"Expected seq_calib_progress to be a dict, got {type(progress).__name__}." 
+ ) + for key in ("completed_layer_idx", "total_layers"): + if key not in progress: + raise ValueError(f"Checkpoint progress is missing required key {key!r}.") + + completed_layer = progress["completed_layer_idx"] + saved_total = progress["total_layers"] + + if not isinstance(completed_layer, int) or not isinstance(saved_total, int): + raise ValueError( + f"Checkpoint progress values must be ints, got " + f"completed_layer_idx={completed_layer!r}, total_layers={saved_total!r}." + ) + + if saved_total != num_layers: + raise ValueError( + f"Checkpoint was saved with {saved_total} layers but model has " + f"{num_layers} layers. Cannot resume." + ) + + if not (0 <= completed_layer < num_layers): + raise ValueError( + f"completed_layer_idx={completed_layer} is out of range for " + f"{num_layers} layers (expected 0..{num_layers - 1})." + ) + + resume_from = completed_layer + 1 + print_rank_0( + f"Resuming sequential calibration from layer {resume_from} " + f"(layers 0..{completed_layer} already calibrated)" + ) + return resume_from + + +def should_save_seq_calib_checkpoint( + layer_idx: int, num_layers: int, checkpoint_dir: str | None, checkpoint_interval: int | None +) -> bool: + """Return *True* when a checkpoint should be saved after calibrating *layer_idx*.""" + if checkpoint_interval is not None and checkpoint_interval <= 0: + raise ValueError( + f"checkpoint_interval must be a positive integer, got {checkpoint_interval}." + ) + return ( + checkpoint_dir is not None + and checkpoint_interval is not None + and (layer_idx + 1) % checkpoint_interval == 0 + and layer_idx < num_layers - 1 # never save after the final layer + ) + + +def save_sequential_checkpoint( + model: nn.Module, + completed_layer_idx: int, + total_layers: int, + checkpoint_dir: str, +) -> None: + """Save a rolling checkpoint during sequential calibration. + + Temporarily attaches progress to the model so that ``update_quantize_metadata`` + can serialize it during ``save_pretrained``. 
The attribute is **not** deleted + here — cleanup is owned by :func:`sequential_calibrate`'s ``finally`` block. + """ + saver = get_checkpoint_saver(model) + if saver is None: + print_rank_0( + "Warning: checkpoint_dir is set but no checkpoint saver is registered " + "for this model type. Skipping checkpoint save." + ) + return + + setattr( + model, + SEQ_CALIB_PROGRESS_ATTR, + { + "completed_layer_idx": completed_layer_idx, + "total_layers": total_layers, + }, + ) + + os.makedirs(checkpoint_dir, exist_ok=True) + saver(model, checkpoint_dir) + print_rank_0( + f"Saved sequential calibration checkpoint at layer " + f"{completed_layer_idx + 1}/{total_layers} to {checkpoint_dir}" + ) From c705c2436dd32a609984d6297c557086785e1976 Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Mon, 6 Apr 2026 20:41:24 +0000 Subject: [PATCH 46/48] updated e2e test Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- tests/gpu/torch/quantization/test_gptq_vq.py | 96 ++++++++++++++++++++ 1 file changed, 96 insertions(+) create mode 100644 tests/gpu/torch/quantization/test_gptq_vq.py diff --git a/tests/gpu/torch/quantization/test_gptq_vq.py b/tests/gpu/torch/quantization/test_gptq_vq.py new file mode 100644 index 0000000000..8785c95c9d --- /dev/null +++ b/tests/gpu/torch/quantization/test_gptq_vq.py @@ -0,0 +1,96 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test ModelOpt GPTQ with PSX LUTS VQ on a single expert linear layer.""" + +import copy + +import pytest +import torch + +RTN_CFG_NAME = ( + "PSX_LUTS_WEIGHT_VL_VS8_Entries65536_LFSR_max_sorted_bs16_ACTIVATION_NONE_CFG_routed_moes" +) +GPTQ_CFG_NAME = ( + "GPTQ_PSX_LUTS_WEIGHT_VL_VS8_Entries65536_LFSR_max_sorted_bs16_ACTIVATION_NONE_CFG_routed_moes" +) + + +def _configs_available(): + try: + import modelopt.torch.quantization as mtq + + return getattr(mtq, RTN_CFG_NAME, None) is not None + except Exception: + return False + + +class _SingleExpertModel(torch.nn.Module): + """Wraps a Linear so its path contains 'experts' to match the quant_cfg patterns.""" + + def __init__(self, in_features, out_features): + super().__init__() + self.experts = torch.nn.ModuleList([torch.nn.Linear(in_features, out_features, bias=False)]) + + def forward(self, x): + return self.experts[0](x) + + +@pytest.mark.skipif(not _configs_available(), reason="PSX LUTS plugin configs not available") +def test_modelopt_gptq_vs_rtn(): + """GPTQ should produce lower output NMSE than RTN on a single expert layer.""" + import modelopt.torch.quantization as mtq + + rtn_cfg = copy.deepcopy(getattr(mtq, RTN_CFG_NAME)) + gptq_cfg = copy.deepcopy(getattr(mtq, GPTQ_CFG_NAME)) + # Single-layer model has no decoder layers for sequential calibration + gptq_cfg["algorithm"]["use_sequential"] = False + + torch.manual_seed(42) + out_features, in_features, n_samples = 64, 256, 128 + + model = _SingleExpertModel(in_features, out_features).cuda().float() + orig_weight = model.experts[0].weight.data.clone() + calib_data = [torch.randn(1, in_features, device="cuda") for _ in range(n_samples)] + + def forward_loop(m): + for x in calib_data: + m(x) + + # RTN (fold weights so we get the actual QDQ'd values) + rtn_model = mtq.quantize(copy.deepcopy(model), rtn_cfg, forward_loop=forward_loop) + 
mtq.fold_weight(rtn_model) + rtn_weight = rtn_model.experts[0].weight.data.float() + + # GPTQ + gptq_model = mtq.quantize(copy.deepcopy(model), gptq_cfg, forward_loop=forward_loop) + gptq_weight = gptq_model.experts[0].weight.data.float() + + # Output NMSE + act = torch.cat(calib_data, dim=0).squeeze().T # (in_features, n_samples) + w = orig_weight.float() + ref_norm_sq = (w @ act).norm() ** 2 + nmse_rtn = ((rtn_weight - w) @ act).norm() ** 2 / ref_norm_sq + nmse_gptq = ((gptq_weight - w) @ act).norm() ** 2 / ref_norm_sq + + print(f"\nRTN NMSE: {nmse_rtn:.8f}") + print(f"GPTQ NMSE: {nmse_gptq:.8f}") + print(f"GPTQ gain over RTN: {nmse_rtn / nmse_gptq:.4f}x") + + assert nmse_gptq < nmse_rtn, "GPTQ should beat RTN" + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s"]) From 6d8f9b310c2ec8baa92c2f66ba0977e3cf1e0ec8 Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Tue, 7 Apr 2026 19:16:27 +0000 Subject: [PATCH 47/48] new perplexity eval Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- examples/llm_ptq/hf_ptq.py | 37 +++++++++++++++++++++++++++++++++++-- 1 file changed, 35 insertions(+), 2 deletions(-) diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index 90a5dd68ea..84fe2666f8 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -1081,9 +1081,30 @@ def quantize_main( if args.fold_weights: print("Folding weights before perplexity evaluation...") mtq.fold_weight(language_model) - eval_data = get_wikitext2(tokenizer, args.eval_perplexity_seq_len) + if args.eval_perplexity_input_path: + print(f"Loading perplexity eval data from {args.eval_perplexity_input_path}") + eval_data = torch.load( + args.eval_perplexity_input_path, map_location="cpu", weights_only=True + ) + # Unbatch to [1, seq_len] per element for compute_perplexity batching + unbatched = [] + for t in eval_data: + if not isinstance(t, torch.Tensor): + continue + if t.dim() == 
1: + unbatched.append(t.unsqueeze(0)) + elif t.shape[0] > 1: + unbatched.extend(t.unbind(0)) + else: + unbatched.append(t) + eval_data = [t.unsqueeze(0) if t.dim() == 1 else t for t in unbatched] + print(f"Loaded {len(eval_data)} sequences from {args.eval_perplexity_input_path}") + label = args.eval_perplexity_input_path + else: + eval_data = get_wikitext2(tokenizer, args.eval_perplexity_seq_len) + label = "Wikitext-2" ppl = compute_perplexity(full_model, eval_data) - print(f"Wikitext-2 perplexity: {ppl:.2f}") + print(f"{label} perplexity: {ppl:.2f}") # Plugin-registered configs (e.g. PSX LUTS from modelopt-internal) are not exportable # via the standard TRT-LLM / HF export paths. Fall back to save_pretrained(). @@ -1341,6 +1362,18 @@ def parse_args() -> argparse.Namespace: default=2048, help="Sequence length for perplexity evaluation (default: 2048).", ) + parser.add_argument( + "--eval_perplexity_input_path", + type=str, + default=None, + help=( + "Path to a .pt file containing pre-tokenized evaluation data " + "(List[Tensor], each [1, seq_len]) for perplexity evaluation. " + "When set, this data is used instead of WikiText-2. " + "Compatible with the .pt files produced by create_holdout_mse_inputs.py " + "or --activation_mse_input_path." 
+ ), + ) parser.add_argument( "--measure_activation_mse", action=argparse.BooleanOptionalAction, From 224f77d4efa7cc388ed29ba7628cb411b0faebb6 Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Thu, 9 Apr 2026 01:02:55 +0000 Subject: [PATCH 48/48] update Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- examples/llm_ptq/hf_ptq.py | 63 +++++-- modelopt/torch/quantization/model_calib.py | 2 - .../utils/activation_collector.py | 129 +++---------- .../torch/quantization/utils/checkpoint.py | 173 ------------------ 4 files changed, 79 insertions(+), 288 deletions(-) delete mode 100644 modelopt/torch/quantization/utils/checkpoint.py diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index 84fe2666f8..725a9161e7 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -248,33 +248,74 @@ def make_calib_dataloader( def _make_mse_holdout_dataloader(args, tokenizer, device): - """Create a hold-out dataloader for activation MSE, skipping calibration samples.""" - from modelopt.torch.utils.dataset_utils import SUPPORTED_DATASET_CONFIG + """Create a hold-out dataloader for activation MSE, excluding calibration samples. + + Uses content-based exclusion: reconstructs the exact calibration texts and + filters them from the hold-out set. This is robust to empty/filtered samples + that cause pure index-based skipping to under-count and produce overlap. + """ + from modelopt.torch.utils.dataset_utils import SUPPORTED_DATASET_CONFIG, get_dataset_samples dataset_names, calib_sizes = args.dataset, args.calib_size n_mse = args.activation_mse_max_samples - # Per-split skip = calib samples / number of splits for that dataset + # 1. Reconstruct the exact calibration texts (same args as make_calib_dataloader). + calib_texts: set[str] = set() + for ds_name, cs in zip(dataset_names, calib_sizes): + calib_texts.update(get_dataset_samples(ds_name, cs, tokenizer=tokenizer)) + + # 2. 
Per-split skip for efficiency (avoids re-iterating calibration range). skip_per_dataset = [] for ds_name, cs in zip(dataset_names, calib_sizes): n_splits = len( SUPPORTED_DATASET_CONFIG.get(ds_name, {}).get("config", {}).get("split", [None]) ) - skip_per_dataset.append(cs // max(n_splits, 1)) + skip_per_dataset.append(-(-cs // max(n_splits, 1))) skip = max(skip_per_dataset) - # Distribute hold-out samples proportionally across datasets + # 3. Distribute hold-out samples proportionally across datasets. total_calib = sum(calib_sizes) holdout_sizes = [max(1, int(n_mse * cs / total_calib)) for cs in calib_sizes] holdout_sizes[-1] = n_mse - sum(holdout_sizes[:-1]) - return get_dataset_dataloader( - dataset_name=dataset_names, - tokenizer=tokenizer, + # 4. Collect hold-out texts, requesting extras to replace any filtered overlaps. + all_holdout_texts: list[str] = [] + for ds_name, hs in zip(dataset_names, holdout_sizes): + texts = get_dataset_samples( + ds_name, + hs + len(calib_texts), + tokenizer=tokenizer, + skip_samples=skip, + ) + filtered = [t for t in texts if t not in calib_texts][:hs] + all_holdout_texts.extend(filtered) + + # 5. Tokenize and build dataloader (mirrors get_dataset_dataloader logic). 
+ tok = copy.deepcopy(tokenizer) + batch_encoded = tok( + all_holdout_texts, + return_tensors="pt", + padding=True, + truncation=True, + max_length=512, + ) + if device: + batch_encoded = batch_encoded.to(device) + + class _InputIdsDataset(torch.utils.data.Dataset): + def __init__(self, input_ids): + self.input_ids = input_ids + + def __getitem__(self, idx): + return {"input_ids": self.input_ids[idx]} + + def __len__(self): + return len(self.input_ids) + + return DataLoader( + _InputIdsDataset(batch_encoded["input_ids"]), batch_size=args.batch_size, - num_samples=holdout_sizes, - device=device, - skip_samples=skip, + shuffle=False, ) diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index cb95a16ea2..257f8f4979 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -1581,8 +1581,6 @@ def sequential_calibrate( try: for layer_idx, layer in enumerate(transformer_layers): print_rank_0(f"Calibrating layer {layer_idx + 1}/{len(transformer_layers)}") - # Store layer_idx so gptq/GPTQHelper can access it for debugging - layer._seq_calib_layer_idx = layer_idx layer_inputs = input_getter.get_input_activations(layer, forward_loop) def _layer_forward_loop(m, _inputs=layer_inputs): diff --git a/modelopt/torch/quantization/utils/activation_collector.py b/modelopt/torch/quantization/utils/activation_collector.py index 9b7c4a3406..5f187fdcb2 100644 --- a/modelopt/torch/quantization/utils/activation_collector.py +++ b/modelopt/torch/quantization/utils/activation_collector.py @@ -22,7 +22,7 @@ from collections import deque from dataclasses import dataclass, field -from typing import Any, Literal +from typing import Any import torch import torch.nn as nn @@ -44,12 +44,11 @@ class _LayerCalibState: patched forward to decide skip / run / capture / original behaviour. 
""" - mode: Literal["original", "skip", "run", "capture"] = "original" + mode: str = "original" name: str = "" cached_inputs: deque = field(default_factory=deque) collected_inputs: list = field(default_factory=list) output_meta: tuple | None = None - capture_output_meta: bool = False class LayerActivationCollector: @@ -151,10 +150,7 @@ def _zeros_from_meta(meta): # downstream run-mode layer, which replays from its own cached inputs instead. return meta[1] - def _patch_all_layers( - self, - decoder_layers: nn.ModuleList | None = None, - ): + def _patch_all_layers(self, decoder_layers: nn.ModuleList | None = None): """Bind the unified forward to every decoder layer and the model. Called once. Args: @@ -192,10 +188,7 @@ def _patched_forward(self, *args, **kwargs): info.collected_inputs.append((args, kwargs)) raise _EarlyStopForwardError() - output = self._original_forward(*args, **kwargs) - if info.capture_output_meta: - info.output_meta = LayerActivationCollector._extract_output_meta(output) - return output + return self._original_forward(*args, **kwargs) if decoder_layers is not None: self._decoder_layers = decoder_layers @@ -207,7 +200,7 @@ def _patched_forward(self, *args, **kwargs): module_to_name = {m: name for name, m in self.model.named_modules()} try: - for i, layer in enumerate(self._decoder_layers): + for layer in self._decoder_layers: layer._seq_calib = _LayerCalibState( name=module_to_name.get(layer, type(layer).__name__), ) @@ -245,26 +238,6 @@ def _unpatch_all_layers(self): self._cleanup_layers() self._patched = False - def _set_layer_mode( - self, layer_idx: int, mode: Literal["original", "skip", "run", "capture"] - ) -> None: - """Set the mode for a single decoder layer with appropriate side effects.""" - assert self._decoder_layers is not None - state = self._decoder_layers[layer_idx]._seq_calib - state.mode = mode - - if mode == "skip": - state.cached_inputs.clear() - elif mode == "run": - if not state.collected_inputs: - raise RuntimeError( - 
f"Layer {layer_idx} ({state.name!r}) has no collected inputs to replay." - ) - state.cached_inputs = deque(state.collected_inputs) - state.collected_inputs = [] - elif mode == "capture": - state.collected_inputs = [] - def _set_layer_states(self, layer_idx: int): """Transition layer modes for the next calibration step. @@ -274,11 +247,30 @@ def _set_layer_states(self, layer_idx: int): * Layer ``i - 1`` → **run** (replay captured inputs with calibrated weights). * Layer ``i`` → **capture** (record inputs, then early-stop). """ + assert self._decoder_layers is not None + if layer_idx > 1: - self._set_layer_mode(layer_idx - 2, "skip") + done = self._decoder_layers[layer_idx - 2]._seq_calib + # output_meta is intentionally kept: skip mode needs it to produce + # correctly shaped zero-filled outputs for the parent forward. + done.mode = "skip" + done.cached_inputs.clear() + if layer_idx > 0: - self._set_layer_mode(layer_idx - 1, "run") - self._set_layer_mode(layer_idx, "capture") + prev = self._decoder_layers[layer_idx - 1]._seq_calib + if not prev.collected_inputs: + raise RuntimeError( + f"Layer {layer_idx - 1} ({prev.name!r}) has no collected inputs to replay. " + "Layers must be calibrated sequentially — ensure get_input_activations() " + "was called for every preceding layer in order." 
+ ) + prev.mode = "run" + prev.cached_inputs = deque(prev.collected_inputs) + prev.collected_inputs = [] + + cur = self._decoder_layers[layer_idx]._seq_calib + cur.mode = "capture" + cur.collected_inputs = [] def _log_layer_summary(self, layer_idx: int): """Log a one-line summary of layer modes for the current calibration step.""" @@ -292,30 +284,6 @@ def _log_layer_summary(self, layer_idx: int): parts = [f"{mode}: {groups[mode]}" for mode in ("skip", "run", "capture") if mode in groups] print_rank_0(f"Calibrating layer {layer_idx + 1}/{n} | {' | '.join(parts)}") - def _run_warmup_capture(self, capture_layer_idx: int, forward_loop: ForwardLoop) -> None: - """Run a forward pass with *capture_layer_idx* in capture mode. - - Raises RuntimeError if no inputs are collected. - """ - assert self._decoder_layers is not None - state = self._decoder_layers[capture_layer_idx]._seq_calib - state.mode = "capture" - state.collected_inputs = [] - - try: - forward_loop(self.model) - except Exception: - state.mode = "original" - state.collected_inputs = [] - raise - - if not state.collected_inputs: - state.mode = "original" - raise RuntimeError( - f"Warm-up forward collected no inputs for layer {capture_layer_idx}. " - "Cannot resume sequential calibration." - ) - # ------------------------------------------------------------------ # Public API # ------------------------------------------------------------------ @@ -365,46 +333,3 @@ def get_input_activations(self, layer: torch.nn.Module, forward_loop: ForwardLoo # in subsequent iterations via _set_layer_states. info.mode = "original" return inputs - - @torch.no_grad() - def prepare_for_resume( - self, - resume_layer_idx: int, - forward_loop: ForwardLoop, - ): - """Set up layer states for resuming sequential calibration from a checkpoint. - - Runs a single warm-up forward pass so that the next call to - :meth:`get_input_activations` for ``resume_layer_idx`` produces the - correct inputs. Layers ``0 .. 
K-2`` run in *original* mode (with - ``capture_output_meta`` enabled so skip-mode metadata is populated), - layer ``K-1`` in *capture* mode. After the pass, ``0 .. K-2`` switch - to *skip* and ``K-1`` retains its ``collected_inputs`` for the - subsequent *run* transition. - """ - if not self._patched: - raise RuntimeError( - "prepare_for_resume() requires _patch_all_layers() to be called first." - ) - if resume_layer_idx == 0: - return - - k = resume_layer_idx - preceding = range(k - 1) - - assert self._decoder_layers is not None - for i in preceding: - self._set_layer_mode(i, "original") - self._decoder_layers[i]._seq_calib.capture_output_meta = True - - print_rank_0( - f"Running warm-up forward pass for resume " - f"(layers 0..{k - 2} original, layer {k - 1} capture)" - ) - self._run_warmup_capture(k - 1, forward_loop) - - for i in preceding: - self._decoder_layers[i]._seq_calib.capture_output_meta = False - self._set_layer_mode(i, "skip") - - print_rank_0(f"Warm-up complete. Ready to resume from layer {k}.") diff --git a/modelopt/torch/quantization/utils/checkpoint.py b/modelopt/torch/quantization/utils/checkpoint.py deleted file mode 100644 index aa4fd4e538..0000000000 --- a/modelopt/torch/quantization/utils/checkpoint.py +++ /dev/null @@ -1,173 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Checkpoint save/resume utilities for sequential calibration. - -Provides: - -* A pluggable **save registry** — plugins (e.g. huggingface.py) register a - ``(predicate, save_fn)`` pair at import time so that - :func:`get_checkpoint_saver` can find the right saver for any model. - -* **Resume detection** — :func:`detect_sequential_resume_layer` reads progress - metadata previously attached to the model and returns the layer index to - resume from. - -* **Checkpoint saving** — :func:`save_sequential_checkpoint` attaches progress - to the model and delegates to the registered saver. -""" - -from __future__ import annotations - -import os -from typing import TYPE_CHECKING - -from modelopt.torch.utils import print_rank_0 - -if TYPE_CHECKING: - from collections.abc import Callable - - import torch.nn as nn - -#: Model attribute name used to store sequential calibration progress. -SEQ_CALIB_PROGRESS_ATTR = "_seq_calib_progress" - -# --------------------------------------------------------------------------- -# Save registry -# --------------------------------------------------------------------------- -_CHECKPOINT_SAVE_SUPPORT: list[ - tuple[Callable[[nn.Module], bool], Callable[[nn.Module, str], None]] -] = [] - - -def register_seq_calib_checkpoint_saver( - is_supported: Callable[[nn.Module], bool], - save_fn: Callable[[nn.Module, str], None], -) -> None: - """Register a ``(predicate, saver)`` pair for sequential calibration checkpointing.""" - entry = (is_supported, save_fn) - if entry not in _CHECKPOINT_SAVE_SUPPORT: - _CHECKPOINT_SAVE_SUPPORT.append(entry) - - -def get_checkpoint_saver( - model: nn.Module, -) -> Callable[[nn.Module, str], None] | None: - """Return the registered save function for *model*, or *None*.""" - for is_supported, save_fn in _CHECKPOINT_SAVE_SUPPORT: - if is_supported(model): - return save_fn - return None - - -def detect_sequential_resume_layer(model: nn.Module, num_layers: int) -> int: - """Read checkpoint progress from the model and 
return the layer index to resume from. - - Returns ``0`` for a fresh run with no checkpoint present. - The attribute is **not** deleted here — cleanup is owned by - :func:`sequential_calibrate`'s ``finally`` block. - """ - progress = getattr(model, SEQ_CALIB_PROGRESS_ATTR, None) - if progress is None: - return 0 - - if not isinstance(progress, dict): - raise ValueError( - f"Expected seq_calib_progress to be a dict, got {type(progress).__name__}." - ) - for key in ("completed_layer_idx", "total_layers"): - if key not in progress: - raise ValueError(f"Checkpoint progress is missing required key {key!r}.") - - completed_layer = progress["completed_layer_idx"] - saved_total = progress["total_layers"] - - if not isinstance(completed_layer, int) or not isinstance(saved_total, int): - raise ValueError( - f"Checkpoint progress values must be ints, got " - f"completed_layer_idx={completed_layer!r}, total_layers={saved_total!r}." - ) - - if saved_total != num_layers: - raise ValueError( - f"Checkpoint was saved with {saved_total} layers but model has " - f"{num_layers} layers. Cannot resume." - ) - - if not (0 <= completed_layer < num_layers): - raise ValueError( - f"completed_layer_idx={completed_layer} is out of range for " - f"{num_layers} layers (expected 0..{num_layers - 1})." - ) - - resume_from = completed_layer + 1 - print_rank_0( - f"Resuming sequential calibration from layer {resume_from} " - f"(layers 0..{completed_layer} already calibrated)" - ) - return resume_from - - -def should_save_seq_calib_checkpoint( - layer_idx: int, num_layers: int, checkpoint_dir: str | None, checkpoint_interval: int | None -) -> bool: - """Return *True* when a checkpoint should be saved after calibrating *layer_idx*.""" - if checkpoint_interval is not None and checkpoint_interval <= 0: - raise ValueError( - f"checkpoint_interval must be a positive integer, got {checkpoint_interval}." 
- ) - return ( - checkpoint_dir is not None - and checkpoint_interval is not None - and (layer_idx + 1) % checkpoint_interval == 0 - and layer_idx < num_layers - 1 # never save after the final layer - ) - - -def save_sequential_checkpoint( - model: nn.Module, - completed_layer_idx: int, - total_layers: int, - checkpoint_dir: str, -) -> None: - """Save a rolling checkpoint during sequential calibration. - - Temporarily attaches progress to the model so that ``update_quantize_metadata`` - can serialize it during ``save_pretrained``. The attribute is **not** deleted - here — cleanup is owned by :func:`sequential_calibrate`'s ``finally`` block. - """ - saver = get_checkpoint_saver(model) - if saver is None: - print_rank_0( - "Warning: checkpoint_dir is set but no checkpoint saver is registered " - "for this model type. Skipping checkpoint save." - ) - return - - setattr( - model, - SEQ_CALIB_PROGRESS_ATTR, - { - "completed_layer_idx": completed_layer_idx, - "total_layers": total_layers, - }, - ) - - os.makedirs(checkpoint_dir, exist_ok=True) - saver(model, checkpoint_dir) - print_rank_0( - f"Saved sequential calibration checkpoint at layer " - f"{completed_layer_idx + 1}/{total_layers} to {checkpoint_dir}" - )