diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index 5620ddf6a4..725a9161e7 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -60,7 +60,13 @@ save_expert_token_count_table, ) from modelopt.torch.export.model_utils import get_language_model_from_vl, is_multimodal_model +from modelopt.torch.export.plugins.vllm_fakequant_hf import export_hf_vllm_fq_checkpoint from modelopt.torch.quantization.config import _default_disabled_quantizer_cfg, need_calibration +from modelopt.torch.quantization.metrics_backup import ( + ActivationMSELogger, + compute_perplexity, + get_wikitext2, +) from modelopt.torch.quantization.plugins.accelerate import init_quantized_weights from modelopt.torch.quantization.utils import is_quantized from modelopt.torch.utils.dataset_utils import ( @@ -241,6 +247,78 @@ def make_calib_dataloader( return calib_dataloader, first_text_speech_dataset +def _make_mse_holdout_dataloader(args, tokenizer, device): + """Create a hold-out dataloader for activation MSE, excluding calibration samples. + + Uses content-based exclusion: reconstructs the exact calibration texts and + filters them from the hold-out set. This is robust to empty/filtered samples + that cause pure index-based skipping to under-count and produce overlap. + """ + from modelopt.torch.utils.dataset_utils import SUPPORTED_DATASET_CONFIG, get_dataset_samples + + dataset_names, calib_sizes = args.dataset, args.calib_size + n_mse = args.activation_mse_max_samples + + # 1. Reconstruct the exact calibration texts (same args as make_calib_dataloader). + calib_texts: set[str] = set() + for ds_name, cs in zip(dataset_names, calib_sizes): + calib_texts.update(get_dataset_samples(ds_name, cs, tokenizer=tokenizer)) + + # 2. Per-split skip for efficiency (avoids re-iterating calibration range). 
+ skip_per_dataset = [] + for ds_name, cs in zip(dataset_names, calib_sizes): + n_splits = len( + SUPPORTED_DATASET_CONFIG.get(ds_name, {}).get("config", {}).get("split", [None]) + ) + skip_per_dataset.append(-(-cs // max(n_splits, 1))) + skip = max(skip_per_dataset) + + # 3. Distribute hold-out samples proportionally across datasets. + total_calib = sum(calib_sizes) + holdout_sizes = [max(1, int(n_mse * cs / total_calib)) for cs in calib_sizes] + holdout_sizes[-1] = n_mse - sum(holdout_sizes[:-1]) + + # 4. Collect hold-out texts, requesting extras to replace any filtered overlaps. + all_holdout_texts: list[str] = [] + for ds_name, hs in zip(dataset_names, holdout_sizes): + texts = get_dataset_samples( + ds_name, + hs + len(calib_texts), + tokenizer=tokenizer, + skip_samples=skip, + ) + filtered = [t for t in texts if t not in calib_texts][:hs] + all_holdout_texts.extend(filtered) + + # 5. Tokenize and build dataloader (mirrors get_dataset_dataloader logic). + tok = copy.deepcopy(tokenizer) + batch_encoded = tok( + all_holdout_texts, + return_tensors="pt", + padding=True, + truncation=True, + max_length=512, + ) + if device: + batch_encoded = batch_encoded.to(device) + + class _InputIdsDataset(torch.utils.data.Dataset): + def __init__(self, input_ids): + self.input_ids = input_ids + + def __getitem__(self, idx): + return {"input_ids": self.input_ids[idx]} + + def __len__(self): + return len(self.input_ids) + + return DataLoader( + _InputIdsDataset(batch_encoded["input_ids"]), + batch_size=args.batch_size, + shuffle=False, + ) + + def auto_quantize( args: argparse.Namespace, language_model: torch.nn.Module, @@ -686,11 +764,17 @@ def export_quantized( if mtp_layer_prefixes: full_model._mtp_layer_prefixes = mtp_layer_prefixes - export_hf_checkpoint( - full_model, - export_dir=export_path, - extra_state_dict=mtp_state_dict, - ) + if args.vllm_fakequant_export: + export_hf_vllm_fq_checkpoint( + full_model, + export_dir=export_path, + ) + else: + export_hf_checkpoint( + 
full_model, + export_dir=export_path, + extra_state_dict=mtp_state_dict, + ) # Restore default padding and export the tokenizer as well. if tokenizer is not None: @@ -747,7 +831,7 @@ def pre_quantize( allow_fallback=False, ) else: - generated_ids_before_ptq = full_model.generate(preview_input_ids, max_new_tokens=100) + generated_ids_before_ptq = full_model.generate(preview_input_ids, max_new_tokens=2) return preview_input_ids, generated_ids_before_ptq @@ -786,7 +870,7 @@ def post_quantize( pass elif model_type != "llama4" and not is_nemotron_vl_model: # Our fake quantizer may not be fully compatible with torch.compile. - generated_ids_after_ptq = full_model.generate(preview_input_ids, max_new_tokens=100) + generated_ids_after_ptq = full_model.generate(preview_input_ids, max_new_tokens=2) elif is_nemotron_vl_model and tokenizer is not None: generated_ids_after_ptq = run_nemotron_vl_preview( full_model, @@ -910,6 +994,9 @@ def quantize_main( args, full_model, model_type, tokenizer, calib_dataloader, is_nemotron_vl_model ) + mse_logger = None + mse_data = None + if args.auto_quantize_bits: assert len(args.qformat.split(",")) > 1, ( "Auto quantization needs multiple quantization format." @@ -937,10 +1024,17 @@ def quantize_main( "Plain quantization supports only one quantization format." ) - assert args.qformat in QUANT_CFG_CHOICES, ( - f"Unsupported quantization format: {args.qformat}, choices are: {list(QUANT_CFG_CHOICES.keys())}" - ) - quant_cfg = QUANT_CFG_CHOICES[args.qformat] + if args.qformat in QUANT_CFG_CHOICES: + quant_cfg = QUANT_CFG_CHOICES[args.qformat] + else: + # Fallback: resolve dynamically registered configs from the mtq namespace + # (e.g., PSX LUTS configs registered by modelopt-internal plugins). 
+ quant_cfg = getattr(mtq, args.qformat, None) + assert quant_cfg is not None, ( + f"Unsupported quantization format: {args.qformat}, " + f"not found in built-in choices {list(QUANT_CFG_CHOICES.keys())} " + f"or in the mtq namespace (check that the required plugin is installed)." + ) quant_cfg = build_quant_cfg( args.qformat, @@ -976,7 +1070,23 @@ def quantize_main( quant_cfg = copy.deepcopy(quant_cfg) _set_kv_cache_constant_amax(quant_cfg["quant_cfg"]) - if args.qformat in QUANT_CFG_CHOICES: + # Collect original (unquantized) activations before quantization modifies the model + if args.measure_activation_mse: + mse_logger = ActivationMSELogger( + max_samples=args.activation_mse_max_samples, + save_dir=args.activation_mse_save_dir, + ) + mse_data = ActivationMSELogger.resolve_data( + input_path=args.activation_mse_input_path, + calib_dataloader=calib_dataloader, + tokenizer=tokenizer, + max_samples=args.activation_mse_max_samples, + max_length=args.calib_seq, + make_holdout_fn=lambda: _make_mse_holdout_dataloader(args, tokenizer, device), + ) + mse_logger.collect(language_model, mse_data, phase="original") + + if args.qformat in QUANT_CFG_CHOICES or hasattr(mtq, args.qformat): mono_quantize( args, quant_cfg, @@ -1002,15 +1112,67 @@ def quantize_main( is_nemotron_vl_model, first_text_speech_dataset, ) - export_quantized( - args, - full_model, - language_model, - model_type, - tokenizer, - default_padding_side, - default_pad_token, - ) + + if mse_logger is not None: + mse_logger.finish(language_model, mse_data) + del mse_logger, mse_data + torch.cuda.empty_cache() + + if args.eval_perplexity and tokenizer is not None: + if args.fold_weights: + print("Folding weights before perplexity evaluation...") + mtq.fold_weight(language_model) + if args.eval_perplexity_input_path: + print(f"Loading perplexity eval data from {args.eval_perplexity_input_path}") + eval_data = torch.load( + args.eval_perplexity_input_path, map_location="cpu", weights_only=True + ) + # Unbatch to 
[1, seq_len] per element for compute_perplexity batching + unbatched = [] + for t in eval_data: + if not isinstance(t, torch.Tensor): + continue + if t.dim() == 1: + unbatched.append(t.unsqueeze(0)) + elif t.shape[0] > 1: + unbatched.extend(t.unbind(0)) + else: + unbatched.append(t) + eval_data = [t.unsqueeze(0) if t.dim() == 1 else t for t in unbatched] + print(f"Loaded {len(eval_data)} sequences from {args.eval_perplexity_input_path}") + label = args.eval_perplexity_input_path + else: + eval_data = get_wikitext2(tokenizer, args.eval_perplexity_seq_len) + label = "Wikitext-2" + ppl = compute_perplexity(full_model, eval_data) + print(f"{label} perplexity: {ppl:.2f}") + + # Plugin-registered configs (e.g. PSX LUTS from modelopt-internal) are not exportable + # via the standard TRT-LLM / HF export paths. Fall back to save_pretrained(). + if args.qformat not in QUANT_CFG_CHOICES and hasattr(mtq, args.qformat): + export_path = args.export_path + if args.vllm_fakequant_export: + print(f"Exporting vLLM fakequant checkpoint (bf16 weights + amax) to: {export_path}") + export_hf_vllm_fq_checkpoint(full_model, export_dir=export_path) + else: + print( + f"qformat '{args.qformat}' is a plugin-registered config and is not exportable " + f"via the standard export pipeline. Saving with save_pretrained() instead." + ) + full_model.save_pretrained(export_path) + if tokenizer is not None: + tokenizer.save_pretrained(export_path) + print(f"Quantized model saved to: {export_path}") + else: + export_quantized( + args, + full_model, + language_model, + model_type, + tokenizer, + default_padding_side, + default_pad_token, + ) def parse_args() -> argparse.Namespace: @@ -1064,6 +1226,16 @@ def parse_args() -> argparse.Namespace: default=512, ) parser.add_argument("--export_path", default="exported_model") + parser.add_argument( + "--vllm_fakequant_export", + action="store_true", + default=False, + help=( + "Export bf16 weights and amax values separately for vLLM fakequant serving. 
" + "Produces a standard HF checkpoint with GPTQ-adjusted weights plus a " + "quant_amax.pth file that can be loaded via AMAX_FILE_PATH in vllm_serve_fakequant.py." + ), + ) parser.add_argument( "--dataset", help=( @@ -1219,6 +1391,70 @@ def parse_args() -> argparse.Namespace: ), ) + parser.add_argument( + "--eval_perplexity", + action=argparse.BooleanOptionalAction, + default=False, + help="Evaluate Wikitext-2 perplexity after quantization (before export).", + ) + parser.add_argument( + "--eval_perplexity_seq_len", + type=int, + default=2048, + help="Sequence length for perplexity evaluation (default: 2048).", + ) + parser.add_argument( + "--eval_perplexity_input_path", + type=str, + default=None, + help=( + "Path to a .pt file containing pre-tokenized evaluation data " + "(List[Tensor], each [1, seq_len]) for perplexity evaluation. " + "When set, this data is used instead of WikiText-2. " + "Compatible with the .pt files produced by create_holdout_mse_inputs.py " + "or --activation_mse_input_path." + ), + ) + parser.add_argument( + "--measure_activation_mse", + action=argparse.BooleanOptionalAction, + default=False, + help="Measure per-layer activation MSE (original vs quantized) after quantization.", + ) + parser.add_argument( + "--activation_mse_max_samples", + type=int, + default=16, + help="Max calibration samples for activation MSE (default: 16).", + ) + parser.add_argument( + "--activation_mse_save_dir", + type=str, + default=None, + help="Directory to save activation MSE results. If not set, results are only printed.", + ) + parser.add_argument( + "--activation_mse_input_path", + type=str, + default=None, + help=( + "Path to frozen MSE input data. 
Supports two formats:\n" + " .json — raw text (cross-model reuse): if file exists, loads and re-tokenizes " + "with the current model's tokenizer; if not, decodes calibration data to text and saves.\n" + " .pt — tokenized tensors (same-tokenizer reuse): if file exists, loads directly; " + "if not, materializes from calibration data and saves." + ), + ) + parser.add_argument( + "--fold_weights", + action=argparse.BooleanOptionalAction, + default=False, + help=( + "Fold quantized weights before collecting activation MSE. " + "Speeds up the quantized forward pass by replacing weights in-place " + "and disabling fake-quant, but permanently mutates the weights." + ), + ) args = parser.parse_args() if args.moe_calib_experts_ratio is not None and not (0.0 < args.moe_calib_experts_ratio <= 1.0): parser.error("--moe_calib_experts_ratio must be in the range (0.0, 1.0].") diff --git a/examples/vllm_serve/fakequant_worker.py b/examples/vllm_serve/fakequant_worker.py index 772c6fe669..b717b60da5 100644 --- a/examples/vllm_serve/fakequant_worker.py +++ b/examples/vllm_serve/fakequant_worker.py @@ -152,6 +152,7 @@ def disable_compilation(model): "quant_cfg": os.environ.get("QUANT_CFG", None), "kv_quant_cfg": os.environ.get("KV_QUANT_CFG", None), "amax_file_path": os.environ.get("AMAX_FILE_PATH", None), + "skip_fold_weight": os.environ.get("SKIP_FOLD_WEIGHT", "0") == "1", } @@ -329,7 +330,14 @@ def calibrate_loop(model: Any = None) -> None: if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0: mtq.print_quant_summary(model) - mtq.fold_weight(model) + if quant_config["skip_fold_weight"]: + print("Skipping fold_weight (weights already quantized, e.g. 
from GPTQ export)") + for name, module in model.named_modules(): + if name.endswith("weight_quantizer"): + module.disable() + else: + mtq.fold_weight(model) + for name, module in model.named_modules(): if name.endswith("weight_quantizer"): assert not module.is_enabled, f"quantizer {name} is still enabled" diff --git a/examples/vllm_serve/vllm_serve_fakequant.py b/examples/vllm_serve/vllm_serve_fakequant.py index 25483f2be1..fb14564568 100644 --- a/examples/vllm_serve/vllm_serve_fakequant.py +++ b/examples/vllm_serve/vllm_serve_fakequant.py @@ -76,6 +76,7 @@ "QUANT_CFG", "AMAX_FILE_PATH", "KV_QUANT_CFG", + "SKIP_FOLD_WEIGHT", } RayDistributedExecutor.ADDITIONAL_ENV_VARS.update(additional_env_vars) diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index cf2336bf4a..be11108b28 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -1308,42 +1308,33 @@ class SVDQuantConfig(QuantizeAlgorithmConfig): -class GPTQLiteConfig(QuantizeAlgorithmConfig): - """The config for GPTQ lite. +class GPTQConfig(QuantizeAlgorithmConfig): + """The config for GPTQ quantization. - GPTQ lite is a variant of GPTQ that does not exactly follow the official GPTQ implementation. - - GPTQ lite does not perform sequential quantization of layers. This means that the updated - activations are not used to process the next layer. + GPTQ minimizes the layer-wise quantization error by using second-order (Hessian) information + to perform blockwise weight updates that compensate for rounding loss. Layers are quantized + sequentially so that each layer's Hessian is computed from activations that already reflect + the quantization of preceding layers. 
The default values are taken from the official GPTQ implementation: https://github.com/IST-DASLab/FP-Quant/blob/d2e3092f968262c4de5fb050e1aef568a280dadd/src/quantization/gptq.py#L35 - - Note: This feature is currently experimental and may not translate to improved accuracy as expected. - - """ - method: Literal["gptq_lite"] = ModeloptField("gptq_lite") - percdamp: float | None = ModeloptField( + """ + method: Literal["gptq"] = ModeloptField("gptq") + percdamp: float = ModeloptField( default=0.01, gt=0.0, le=1.0, title="Percentage damping factor.", description="The percentage of average Hessian diagonal used for damping.", ) - block_size: int | None = ModeloptField( + block_size: int = ModeloptField( default=128, title="Block size for GPTQ weight update.", description="""The block size for GPTQ weight update, which must be a multiple of the group_size used in the quantization.""", ) - hessian_state_path: str | None = ModeloptField( - default=None, - title="Path to the Hessian state file.", - description="""The path to the Hessian state file. 
If hessian path exists, we load from - hessian file instead of recomputing them.""", - ) QuantizeQuantCfgType = dict[ diff --git a/modelopt/torch/quantization/mode.py b/modelopt/torch/quantization/mode.py index e08efece9a..63b3a7c913 100644 --- a/modelopt/torch/quantization/mode.py +++ b/modelopt/torch/quantization/mode.py @@ -37,7 +37,7 @@ AWQFullCalibConfig, AWQLiteCalibConfig, CompressConfig, - GPTQLiteConfig, + GPTQConfig, LocalHessianCalibConfig, MaxCalibConfig, MseCalibConfig, @@ -59,7 +59,7 @@ ) from .model_calib import ( awq, - gptq_lite, + gptq, local_hessian_calibrate, max_calibrate, mse_calibrate, @@ -240,8 +240,8 @@ def wrapped_calib_func( if sequential: if forward_loop is None: raise ValueError("forward_loop is required for calibration but got None.") - assert method in ["max"], ( - f"Sequential calibration currently only supports max calibration, got {method}" + assert method in ["max", "gptq"], ( + f"Sequential calibration currently only supports max and gptq calibration, got {method}" ) # Wrap with sequential processing sequential_calibrate( @@ -493,12 +493,12 @@ def restore(self) -> RestoreEntrypoint: @CalibrateModeRegistry.register_mode -class GPTQLiteModeDescriptor(BaseCalibrateModeDescriptor): +class GPTQModeDescriptor(BaseCalibrateModeDescriptor): """Mode for GPTQ calibration algorithm.""" @property def config_class(self) -> type[QuantizeAlgorithmConfig]: """Specifies the config class for the mode.""" - return GPTQLiteConfig + return GPTQConfig - _calib_func = gptq_lite + _calib_func = gptq diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index ed57ea3fc7..257f8f4979 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -16,7 +16,7 @@ """Calibration utilities.""" import math -import os +import time import warnings from collections.abc import Callable from functools import partial @@ -39,6 +39,7 @@ from .nn import NVFP4StaticQuantizer, 
QuantModule, SequentialQuantizer, TensorQuantizer from .utils import ( disable_calib, + disabled_weight_quantizers, enable_fake_quant, enable_quant, enable_weight_access_and_writeback, @@ -1519,27 +1520,6 @@ def postprocess(module, name): max_calibrate(model, forward_loop) -def _print_relative_mse_error(q: torch.Tensor, w: torch.Tensor, h: torch.Tensor, module_name: str): - """Print relative mean squared error between quantized and original weights. - - Computes the Hessian-weighted relative MSE between quantized and original weights, - providing a measure of quantization quality. This metric is adapted from the GPTQ - repository. - - Args: - q (torch.Tensor): Quantized weight tensor - w (torch.Tensor): Original weight tensor - h (torch.Tensor): Hessian matrix used for weighting the error - module_name (str): Name of the module for logging purposes - Note: - Implementation adapted from the GPTQ repository: - https://github.com/IST-DASLab/FP-Quant - """ - delta = q - w - mse = (delta).mm(h).mul(delta).mean() / (w.mm(h).mul(w).mean() + 1e-6) - print(f"[{module_name}] Relative MSE error: {mse.item():.2e}") - - def update_hessian(input, hessian, n_samples): """Update hessian matrix with new input samples using incremental formula. @@ -1549,337 +1529,479 @@ def update_hessian(input, hessian, n_samples): n_samples: Number of samples already processed Returns: Tuple of (updated_hessian, new_sample_count) + + Note: input must be non-empty (batch_size > 0); a zero-sized input causes division by zero. 
""" - batch_size = input.shape[0] + # Flatten to 2D (total_tokens, features) first, so batch_size counts tokens + input_flat = input.reshape(-1, input.shape[-1]).t().float() + batch_size = input_flat.shape[1] # Incremental averaging: scale down old hessian hessian *= n_samples / (n_samples + batch_size) n_samples += batch_size # Compute outer product: H += (2/n_samples) * X @ X^T - # where X is the flattened input reshaped to (features, batch*seq) - input_flat = input.reshape(-1, input.shape[-1]).t().float() scaled_input = math.sqrt(2 / n_samples) * input_flat hessian.add_((scaled_input @ scaled_input.t()).to(hessian.device)) return hessian, n_samples -def prepare_hessian_inverse(h, weight, percdamp): - """Prepare inverse Hessian with dead neuron handling and damping. +@torch.no_grad() +def sequential_calibrate( + model: nn.Module, + forward_loop: ForwardLoop, + calib_func: Callable, + **calib_kwargs, +): + """Sequential calibration - a sequential layer-by-layer calibration algorithm. - Args: - h: Hessian matrix to update - weight: Weight tensor to prepare Hessian for - percdamp: Damping percentage for Hessian diagonal - Returns: - h_inv: Inverse Hessian matrix - Implementation adapted from the FP-Quant repository: - https://github.com/IST-DASLab/FP-Quant/blob/d2e3092f968262c4de5fb050e1aef568a280dadd/src/quantization/gptq.py#L200 + Runs the full model forward per layer but patches decoder layers with a + skip / run / capture strategy so that inter-layer logic in parent modules + (e.g. mask construction) executes naturally without model-specific hooks. """ - h = h.clone() - # Handle dead neurons (zero weight columns) - # Get columns with all zeros in weight - zero_cols = torch.nonzero(weight.eq(0).all(dim=0)).unsqueeze(-1) + if forward_loop is None: + raise ValueError( + "forward_loop must not be None for sequential calibration. " + "Please provide a valid forward_loop callable." 
+ ) - # Zero out entire rows and columns in Hessian for dead neurons - h[zero_cols, :] = 0 - h[:, zero_cols] = 0 - h[zero_cols, zero_cols] = 1 + transformer_layers = LayerActivationCollector.get_decoder_layers(model) + if transformer_layers is None or len(transformer_layers) == 0: + raise ValueError( + "Could not find transformer layers in model. " + "Sequential calibration requires a model with identifiable transformer layers." + ) - # Add damping to diagonal - damp = percdamp * torch.mean(torch.diag(h)) - diag_indices = torch.arange(h.shape[0], device=h.device) - h[diag_indices, diag_indices] += damp + print_rank_0(f"Sequential calibration: Found {len(transformer_layers)} transformer layers") + + input_getter = LayerActivationCollector(model) + input_getter._patch_all_layers(decoder_layers=transformer_layers) try: - h = torch.cholesky_inverse(torch.linalg.cholesky(h)) - h_inv = torch.linalg.cholesky(h, upper=True) - except (RuntimeError, torch.linalg.LinAlgError): - print("Warning: Hessian is not positive definite, using identity matrix") - h_inv = torch.eye(h.shape[0], device=h.device, dtype=h.dtype) - return h_inv + for layer_idx, layer in enumerate(transformer_layers): + print_rank_0(f"Calibrating layer {layer_idx + 1}/{len(transformer_layers)}") + layer_inputs = input_getter.get_input_activations(layer, forward_loop) + + def _layer_forward_loop(m, _inputs=layer_inputs): + for args, kwargs_input in _inputs: + m(*args, **kwargs_input) + calib_func(layer, _layer_forward_loop, **calib_kwargs) -def quantize_block(full_weight, block_start, block_end, h_inv, quantizer): - """Quantize a block of weights group by group (based on quantizer block sizes) with error propagation. 
+ del layer_inputs + torch.cuda.empty_cache() + finally: + input_getter._unpatch_all_layers() - Args: - full_weight: The full weight tensor (needed for INT4 quantization) - block_start: Starting column index of the block - block_end: Ending column index of the block - h_inv: Hessian inverse - quantizer: The quantizer to apply - Returns: - quantized_block: Quantized weights for this block - losses: Quantization losses per element - errors: Accumulated errors for propagation - """ - # Extract the block we're working on - block_weight = full_weight[:, block_start:block_end] - block_hinv = h_inv[block_start:block_end, block_start:block_end] - block_size = block_end - block_start + print_rank_0("Sequential calibration completed") - quantized_block = torch.zeros_like(block_weight) - losses = torch.zeros_like(block_weight) - errors = torch.zeros_like(block_weight) - # We perform column-wise update for GPTQ within the block - group_size = 1 +def _promote_nvfp4_static_quantizers(model: nn.Module) -> int: + """Convert eligible TensorQuantizers to NVFP4StaticQuantizer in-place. - for group_start in range(0, block_size, group_size): - group_end = min(group_start + group_size, block_size) - group_cols = slice(group_start, group_end) - # Get current column and its Hessian inverse diagonal - weight_col = block_weight[:, group_cols] - hinv_diag = torch.diag(block_hinv[group_cols, group_cols]) + After max calibration sets per-block amax values, NVFP4 static quantizers + need to be promoted so they use the two-level scaling path (global amax + + per-block amax) instead of the generic E4M3 path. - # Quantize using the full weight, then extract the columns we need - quantized_full = quantizer(full_weight) - quantized_cols = quantized_full[:, block_start + group_start : block_start + group_end] - quantized_block[:, group_cols] = quantized_cols + Returns the number of quantizers converted. 
+ """ + converted = 0 + for module in model.modules(): + if isinstance(module, TensorQuantizer) and not module._disabled: + if module._calibrator is not None and not module._dynamic and hasattr(module, "_amax"): + is_nvfp4_static = ( + module.is_static_block_quant + and module._num_bits == (2, 1) + and module._block_sizes is not None + and module._block_sizes.get("scale_bits") == (4, 3) + ) + if is_nvfp4_static: + initial_amax = module._amax.clone().detach() + global_amax = reduce_amax(initial_amax, axis=None) + NVFP4StaticQuantizer.from_tensor_quantizer(module, global_amax=global_amax) + converted += 1 + return converted - # Compute quantization error and loss - error = (weight_col - quantized_cols) / hinv_diag - losses[:, group_cols] = (weight_col - quantized_cols) ** 2 / (hinv_diag**2) / 2 - errors[:, group_cols] = error - # Propagate error to remaining columns in block - block_weight[:, group_start:] -= error @ block_hinv[group_start:group_end, group_start:] - full_weight[:, block_start:block_end] = block_weight +@torch.no_grad() +def gptq( + model: nn.Module, + forward_loop: ForwardLoop, + percdamp: float = 0.01, + block_size: int = 128, +): + """GPTQ quantization. + + Works in two modes depending on ``use_sequential`` in the config: - return quantized_block, losses, errors + * **Sequential** (``use_sequential=True``): ``sequential_calibrate`` calls this + function once per decoder layer with updated activations, producing more + accurate Hessian estimates. + * **Non-sequential** (``use_sequential=False``): called once on the full model. + All layers are quantized in parallel from the original activations. + Per-module steps: -def blockwise_weight_update(module, h, block_size, percdamp): - """Update module weights using GPTQ-style blockwise quantization. + 1. ``max_calibrate`` to set amax values from the current activations. + 2. Promote eligible quantizers to ``NVFP4StaticQuantizer`` (two-level scaling). + 3. 
Collect per-linear-layer Hessian matrices via forward hooks. + 4. Blockwise weight updates using the inverse Hessian to compensate for + rounding error (the core GPTQ column-wise update). Args: - module: Neural network module with weight and weight_quantizer - H: Hessian matrix (d x d) - block_size: Size of blocks to process at once - percdamp: Damping percentage for Hessian diagonal + model: The module to quantize — either the full model or a single decoder + layer when invoked by ``sequential_calibrate``. + forward_loop: Callable that replays calibration inputs through *model*. + percdamp: Percentage of avg Hessian diagonal for damping (default: 0.01). + block_size: Block size for GPTQ weight update. """ - weight = module.weight.data.float().clone() - _, num_cols = weight.shape - # Preprocess Hessian: handle dead neurons and add damping - h_inv = prepare_hessian_inverse(h, weight, percdamp) + class GPTQHelper: + """Encapsulates per-module GPTQ state and operations. - # Initialize output tensors - quantized_weight = torch.zeros_like(weight) - losses = torch.zeros_like(weight) + Owns the Hessian, patches the forward during collection, and contains + the blockwise weight-update logic. 
- # Process weights in blocks - for block_start in range(0, num_cols, block_size): - block_end = min(block_start + block_size, num_cols) + Instance attributes set during ``__init__``: + module, name, hessian, n_samples - quantized_block, block_losses, block_errors = quantize_block( - weight, block_start, block_end, h_inv, module.weight_quantizer - ) - # Store results - quantized_weight[:, block_start:block_end] = quantized_block - losses[:, block_start:block_end] = block_losses + Instance attributes set during ``update_weights``: + weight: float working copy of module weights (mutated in-place by update methods) + h_inv: upper-triangular Cholesky factor of the damped inverse Hessian + """ - # Propagate errors to remaining weights - weight[:, block_end:] -= block_errors @ h_inv[block_start:block_end, block_end:] + CACHE_NAME = "_forward_no_gptq_hessian" - # Print relative mse error - _print_relative_mse_error(quantized_weight, module.weight.float(), h, module.name) - # Update module weights - module.weight.data = quantized_weight.reshape(module.weight.shape).to(module.weight.data.dtype) + def __init__(self, module, name, offload_to_cpu=False): + self.module = module + self.name = name + in_features = module.weight.shape[-1] + device = module.weight.device + if offload_to_cpu and get_used_gpu_mem_fraction(device) > 0.65: + device = "cpu" + self.hessian = torch.zeros(in_features, in_features, dtype=torch.float32, device=device) + self.n_samples = 0 + # Set by update_weights(); listed here for documentation. 
+ self.weight: torch.Tensor | None = None + self.h_inv: torch.Tensor | None = None + def setup(self): + """Patch the module's forward to accumulate Hessian during the collection pass.""" + gptq_helper = self + + def hessian_forward(self, input, *args, **kwargs): + inp = input.to_local() if hasattr(input, "to_local") else input + if self.input_quantizer is not None and self.input_quantizer.is_enabled: + inp = self.input_quantizer(inp) + gptq_helper.hessian, gptq_helper.n_samples = update_hessian( + inp, gptq_helper.hessian, gptq_helper.n_samples + ) + return self._forward_no_gptq_hessian(input, *args, **kwargs) -def gptq_lite( - model: nn.Module, - forward_loop: ForwardLoop | None = None, - percdamp: float = 0.01, - block_size: int = 128, - hessian_state_path: str | None = None, -): - """GPTQ-lite quantization - a simplified GPTQ variant. + bind_forward_method(self.module, hessian_forward, self.CACHE_NAME) - Key differences from GPTQ: - - Layers are quantized in parallel (not sequentially with updated activations) - - Uses group-wise updates instead of column-wise updates + def cleanup(self): + """Unpatch the module's forward method.""" + unpatch_forward_method(self.module, self.CACHE_NAME) - Args: - model: Model to be calibrated. - forward_loop: Callable that forwards calibration data through the model. - percdamp: Percentage of avg Hessian diagonal for damping (default: 0.01). - block_size: Block size for GPTQ weight update. - hessian_state_path: Path to save/load Hessian state. If None, compute without saving. - If path exists, load from it. If path doesnt exist then save computed hessians to path. + def free(self): + """Release Hessian and working tensors to reclaim memory.""" + self.hessian = None + self.weight = None + self.h_inv = None - See :class:`GPTQLiteConfig ` for - details on the remaining arguments. + def update_weights(self, block_size, percdamp): + """Run GPTQ blockwise weight update on this module. 
- Note: This feature is currently experimental and may not translate to improved accuracy as expected. - """ - # Dictionary to store hessian matrices: {layer_name: {"hessian": Tensor, "n_samples": int}} - hessian_state = {} - - def initialize_hessian_state(tensor_mapping): - """Initialize hessian state with zeros.""" - for name, (shape, device) in tensor_mapping.items(): - # Use CPU if GPU memory is tight - target_device = "cpu" if get_used_gpu_mem_fraction(device) > 0.65 else device - hessian_state[name] = { - "hessian": torch.zeros(shape, dtype=torch.float32, device=target_device), - "n_samples": 0, - } + Populates ``self.weight`` and ``self.h_inv``, runs the blockwise update, + logs MSE, and writes the result back to the module. + """ + hessian = self.hessian.to(self.module.weight.device) + self.weight = self.module.weight.data.float().clone() + self._prepare_hessian_inverse(hessian, percdamp) - def load_hessian_state(path, tensor_mapping): - """Load hessian state from file.""" - print_rank_0(f"Loading hessian state from {path}") - loaded_state = torch.load(path, map_location="cpu") + self._blockwise_update(block_size) - for name, (shape, device) in tensor_mapping.items(): - if name not in loaded_state: - raise KeyError(f"Layer '{name}' not found in loaded hessian state") + self._print_mse_error(hessian) + self.module.weight.data = self.weight.reshape(self.module.weight.shape).to( + self.module.weight.data.dtype + ) - # Move to appropriate device based on memory - target_device = "cpu" if get_used_gpu_mem_fraction(device) > 0.65 else device - hessian_state[name] = { - "hessian": loaded_state[name]["hessian"].to(target_device), - "n_samples": loaded_state[name]["n_samples"], - } + # ------------------------------------------------------------------ + # Quantize helpers — all read from self.module, self.weight, self.h_inv + # ------------------------------------------------------------------ - print_rank_0(f"Successfully loaded hessian state with 
{len(hessian_state)} layers") + def _prepare_hessian_inverse(self, hessian, percdamp): + """Compute damped inverse Hessian and store as ``self.h_inv``. - def save_hessian_state(path): - """Save hessian state to file.""" - print_rank_0(f"Saving hessian state to {path}") - try: - # Move to CPU for saving - cpu_state = { - name: {"hessian": state["hessian"].cpu(), "n_samples": state["n_samples"]} - for name, state in hessian_state.items() - } + Dead-neuron columns (all-zero in ``self.weight``) are zeroed in the + Hessian before inversion, matching the FP-Quant reference: + https://github.com/IST-DASLab/FP-Quant/blob/d2e3092f968262c4de5fb050e1aef568a280dadd/src/quantization/gptq.py#L200 + """ + assert self.weight is not None, ( + "_prepare_hessian_inverse called before update_weights()" + ) + h = hessian.clone() + zero_cols = torch.nonzero(self.weight.eq(0).all(dim=0)).unsqueeze(-1) + + h[zero_cols, :] = 0 + h[:, zero_cols] = 0 + h[zero_cols, zero_cols] = 1 + + damp = percdamp * torch.mean(torch.diag(h)) + diag_indices = torch.arange(h.shape[0], device=h.device) + h[diag_indices, diag_indices] += damp + + try: + h = torch.cholesky_inverse(torch.linalg.cholesky(h)) + self.h_inv = torch.linalg.cholesky(h, upper=True) + except (RuntimeError, torch.linalg.LinAlgError): + # Retry with 10x more dampening (matches reference implementation) + print_rank_0( + f"Warning: Hessian not positive definite for {self.name}, " + "retrying with 10x dampening" + ) + h[diag_indices, diag_indices] += damp * 10 + try: + h = torch.cholesky_inverse(torch.linalg.cholesky(h)) + self.h_inv = torch.linalg.cholesky(h, upper=True) + except (RuntimeError, torch.linalg.LinAlgError): + print_rank_0( + f"Warning: Hessian still not positive definite for {self.name}, " + "using identity matrix" + ) + self.h_inv = torch.eye(h.shape[0], device=h.device, dtype=h.dtype) - os.makedirs(os.path.dirname(path) if os.path.dirname(path) else ".", exist_ok=True) - torch.save(cpu_state, path) - 
print_rank_0(f"Successfully saved hessian state to {path}") - except Exception as e: - print_rank_0(f"Error saving hessian state: {e}") - print_rank_0("Continuing execution...") + def _blockwise_update(self, block_size): + """Column-wise GPTQ update using full-matrix QDQ. - def hessian_hook(module, input, output): - """Hook to intercept activations and update hessian matrix.""" - state = hessian_state[module.name] - hessian, n_samples = update_hessian(input[0], state["hessian"], state["n_samples"]) - hessian_state[module.name] = {"hessian": hessian, "n_samples": n_samples} + For each column, quantizes the full weight matrix via the quantizer and + extracts the quantized column. This is the standard GPTQ approach. - # Phase 1: Collect statistics for quantizers - max_calibrate(model) + For PSX LUTS vector quantizers, uses a two-phase approach: + 1. Compute scales once per outer block via dynamic quantization + 2. Use static (pre-scaled) quantization in the inner loop - # Phase 2: Build tensor mapping for all quantized layers - tensor_mapping = {} - for name, module in model.named_modules(): - if is_quantized_linear(module) and module.weight_quantizer.is_enabled: - in_features = module.weight.shape[-1] - tensor_mapping[name] = ((in_features, in_features), module.weight.device) - module.name = name # Attach name for easy access in hooks + Reads/writes ``self.weight`` and ``self.h_inv`` in-place. 
+ """ + assert self.weight is not None and self.h_inv is not None, ( + "_blockwise_update called before _prepare_hessian_inverse()" + ) + quantizer = self.module.weight_quantizer + block_sizes = getattr(quantizer, "block_sizes", None) + if block_sizes is not None: + group_size = block_sizes.get(-1) + if group_size is not None and block_size % group_size != 0: + raise ValueError( + f"GPTQ block_size ({block_size}) must be divisible by the quantizer" + f" group_size ({group_size})" + ) - # Phase 3: Load or compute Hessians - hessian_exists = hessian_state_path is not None and os.path.exists(hessian_state_path) - save_hessians = hessian_state_path is not None and not hessian_exists + # Detect PSX LUTS vector quantizer for the fast static-scale path + is_psx_luts_vq = ( + getattr(quantizer, "backend", None) == "psx_luts" + and quantizer.backend_extra_args.get("lut_type", "vector_lut") == "vector_lut" + ) - if hessian_exists: - print_rank_0(f"Loading hessian state from {hessian_state_path}") - load_hessian_state(hessian_state_path, tensor_mapping) - else: - if forward_loop is None: - raise ValueError("forward_loop must be provided when computing Hessians") + if is_psx_luts_vq: + self._blockwise_update_psx_luts(block_size, quantizer) + else: + self._blockwise_update_default(block_size, quantizer) + + def _blockwise_update_default(self, block_size, quantizer): + """Standard GPTQ blockwise update (full QDQ per column).""" + assert self.weight is not None and self.h_inv is not None + num_cols = self.weight.shape[1] + + for block_start in range(0, num_cols, block_size): + block_end = min(block_start + block_size, num_cols) + n_cols_blk = block_end - block_start + h_inv_cho_blk = self.h_inv[block_start:block_end, block_start:block_end] + + wblk = self.weight.clone() + errs = torch.zeros_like(wblk[:, block_start:block_end]) + + for i in range(n_cols_blk): + w_ci = wblk[:, block_start + i] + d = h_inv_cho_blk[i, i] + qdq = quantizer(wblk) + self.weight[:, block_start + i] = 
qdq[:, block_start + i] + err = (w_ci - qdq[:, block_start + i]) / d + wblk[:, block_start + i : block_end].addr_(err, h_inv_cho_blk[i, i:], alpha=-1) + errs[:, i] = err + + self.weight[:, block_end:].addmm_( + errs, self.h_inv[block_start:block_end, block_end:], alpha=-1 + ) - # Initialize hessian state - initialize_hessian_state(tensor_mapping) + @staticmethod + def _dynamic_blockwise_vector_quantization( + x, vector_lut, block_size=16, scale_type="e4m3", return_scales=False + ): + """Dynamic VQ: computes scales from input, returns quantized output (and optionally scales).""" + from luts import clip_vector_scalesign_fast - # Register hooks to collect activations - handles = [] - for name, module in model.named_modules(): - if is_quantized_linear(module) and module.weight_quantizer.is_enabled: - handles.append(module.register_forward_hook(hessian_hook)) + y = clip_vector_scalesign_fast( + x, + vector_lut, + block_size, + scale_type, + scale_algo="max", + sign_scale=True, + return_scales=return_scales, + ) + if return_scales: + return y[0], y[1] + return y - # Run forward loop to compute hessians - print_rank_0("Computing Hessian matrices...") - forward_loop(model) + @staticmethod + def _static_blockwise_vector_quantization(x, vector_lut, scales): + """Static VQ: uses pre-computed scales, returns quantized output.""" + from luts import clip_vector_prescaled - for handle in handles: - handle.remove() + return clip_vector_prescaled(x, vector_lut, scales) - # Save if configured - if save_hessians: - try: - save_hessian_state(hessian_state_path) - except Exception as e: - print_rank_0(f"Error saving hessian state: {e}") - print_rank_0("Continuing execution...") + def _blockwise_update_psx_luts(self, block_size, quantizer): + """GPTQ blockwise update for PSX LUTS vector quantizers. 
- # Phase 4: Update weights using computed Hessians - print_rank_0("Updating weights using GPTQ-lite algorithm...") + Uses dynamic_blockwise_vector_quantization to pre-compute scales, + then static_blockwise_vector_quantization inside the GPTQ loop. - quantized_modules = [ - (name, module) - for name, module in model.named_modules() - if is_quantized_linear(module) and module.weight_quantizer.is_enabled - ] + Follows the 3-loop structure from the VQ GPTQ reference + (adaptive_rounding.py: gptq_quantize_scaled_vq). + """ + print_rank_0(f" [{self.name}] Using PSX LUTS GPTQ path") + extra_args = quantizer.backend_extra_args + encode_format = quantizer.num_bits + encode_path = extra_args.get("encode_path", "") + if encode_path and not encode_path.endswith("/"): + encode_path += "/" + quant_block_size = extra_args.get("block_sizes", 16) + scale_type = extra_args.get("scale_type", "e4m3") + print_rank_0( + f"[GPTQ psx_luts] quant_block_size={quant_block_size}, scale_type={scale_type}" + ) - # Perform blockwise weight updates - for name, module in tqdm(quantized_modules, desc="Quantizing layers"): - state = hessian_state[module.name] - hessian = state["hessian"].to(module.weight.device) - blockwise_weight_update(module, hessian, block_size, percdamp) - # Delete hessian state to free memory - del hessian_state[module.name] - torch.cuda.empty_cache() + # Load the vector LUT codebook + import luts - print_rank_0("GPTQ-lite quantization completed successfully") + if "sorted" not in encode_format: + values, _ = luts.encode(encode_format, path=encode_path, norm=False, cuda=True) + else: + sorted_codebook = torch.load( + encode_path + encode_format + ".pt", map_location="cpu" + ) + values = sorted_codebook["sorted_values"].cuda() + values = values.to(torch.float) + vector_size = values.shape[1] + print_rank_0( + f"[GPTQ psx_luts] vector_size={vector_size}, codebook_shape={values.shape}" + ) + assert self.weight is not None and self.h_inv is not None + out_features, num_cols = 
self.weight.shape -@torch.no_grad() -def sequential_calibrate( - model: nn.Module, - forward_loop: ForwardLoop, - calib_func: Callable, - **calib_kwargs, -): - """Sequential calibration - a sequential layer-by-layer calibration algorithm. + assert block_size % quant_block_size == 0, ( + f"GPTQ block_size ({block_size}) must be a multiple of " + f"quant_block_size ({quant_block_size})" + ) - Runs the full model forward per layer but patches decoder layers with a - skip / run / capture strategy so that inter-layer logic in parent modules - (e.g. mask construction) executes naturally without model-specific hooks. - """ - if forward_loop is None: - raise ValueError( - "forward_loop must not be None for sequential calibration. " - "Please provide a valid forward_loop callable." - ) + # Outside GPTQ loop: dynamic quantization to get scales + _, scales = self._dynamic_blockwise_vector_quantization( + self.weight, + values, + block_size=quant_block_size, + scale_type=scale_type, + return_scales=True, + ) - transformer_layers = LayerActivationCollector.get_decoder_layers(model) - if transformer_layers is None or len(transformer_layers) == 0: - raise ValueError( - "Could not find transformer layers in model. " - "Sequential calibration requires a model with identifiable transformer layers." 
- ) + # Reshape flat scales to 2D for per-vector-group extraction + n_scale_blocks_per_row = num_cols // quant_block_size + scales_2d = scales.reshape(out_features, n_scale_blocks_per_row) + + w = self.weight.clone() + q = torch.zeros_like(w) + h_inv = self.h_inv + + for i in range(0, num_cols, block_size): + j_end = min(i + block_size, num_cols) + e = torch.zeros(out_features, j_end - i, dtype=w.dtype, device=w.device) + + for j in range(i, j_end, vector_size): + d = min(vector_size, j_end - j) + sb = j // quant_block_size + s = scales_2d[:, sb].contiguous() + + # Inside GPTQ loop: static quantization with pre-computed scales + sub_vec = w[:, j : j + d].contiguous() + if d == vector_size: + q_sub = self._static_blockwise_vector_quantization(sub_vec, values, s) + else: + padded = F.pad(sub_vec, (0, vector_size - d)) + q_sub = self._static_blockwise_vector_quantization(padded, values, s)[:, :d] + + q[:, j : j + d] = q_sub + + for k in range(d): + col = j + k + err = (w[:, col] - q[:, col]) / h_inv[col, col] + e[:, col - i] = err + w[:, col:j_end] -= err.unsqueeze(1) * h_inv[col, col:j_end].unsqueeze(0) + + if j_end < num_cols: + w[:, j_end:] -= e @ h_inv[i:j_end, j_end:] + + self.weight = q + + def _print_mse_error(self, hessian): + """Log Hessian-weighted relative MSE between ``self.weight`` and original weights.""" + w_orig = self.module.weight.float() + delta = self.weight - w_orig + mse = delta.mm(hessian).mul(delta).mean() / ( + w_orig.mm(hessian).mul(w_orig).mean() + 1e-6 + ) + suffix = f", n_hessian_samples: {self.n_samples}" if self.n_samples else "" + print_rank_0(f"[{self.name}] Relative MSE error: {mse.item():.2e}{suffix}") - print_rank_0(f"Sequential calibration: Found {len(transformer_layers)} transformer layers") + total_start = time.time() - input_getter = LayerActivationCollector(model) - input_getter._patch_all_layers(decoder_layers=transformer_layers) + max_calibrate(model, forward_loop=forward_loop) + _promote_nvfp4_static_quantizers(model) - try: 
- for layer_idx, layer in enumerate(transformer_layers): - print_rank_0(f"Calibrating layer {layer_idx + 1}/{len(transformer_layers)}") - layer_inputs = input_getter.get_input_activations(layer, forward_loop) + quantized_layers = [ + (n, m) + for n, m in model.named_modules() + if is_quantized_linear(m) and m.weight_quantizer.is_enabled + ] + if not quantized_layers: + print_rank_0("No quantized linear layers found, skipping GPTQ") + return - def _layer_forward_loop(m, _inputs=layer_inputs): - for args, kwargs_input in _inputs: - m(*args, **kwargs_input) + gptq_handles = {name: GPTQHelper(m, name, offload_to_cpu=True) for name, m in quantized_layers} + for handle in gptq_handles.values(): + handle.setup() - calib_func(layer, _layer_forward_loop, **calib_kwargs) + print_rank_0(f"Computing Hessians for {len(gptq_handles)} linear layers...") - del layer_inputs - torch.cuda.empty_cache() - finally: - input_getter._unpatch_all_layers() + with disabled_weight_quantizers(model): + forward_loop(model) + + for handle in gptq_handles.values(): + handle.cleanup() + + print_rank_0("Updating weights using GPTQ algorithm...") + for handle in gptq_handles.values(): + handle.update_weights(block_size, percdamp) + wq = handle.module.weight_quantizer + backend = getattr(wq, "backend", None) + print_rank_0(f" [{handle.name}] weight_quantizer.backend={backend}") + if backend == "psx_luts": + wq.disable() + print_rank_0(f" Disabled weight_quantizer for {handle.name}") + handle.free() + del gptq_handles + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + print_rank_0(f"GPTQ time: {time.time() - total_start:.2f}s") diff --git a/modelopt/torch/quantization/model_quant.py b/modelopt/torch/quantization/model_quant.py index 4aa1ff46b4..7152efacfc 100644 --- a/modelopt/torch/quantization/model_quant.py +++ b/modelopt/torch/quantization/model_quant.py @@ -559,9 +559,6 @@ def get_auto_quantize_config(search_state, constraints=None, verbose=False): # Or use the original result config 
= mtq.get_auto_quantize_config(search_state) - # [Optional] Customize algorithm if needed - config["algorithm"] = {"method": "gptq_lite", "sequential": True} - # Reuse on the same model (e.g. run a longer calibration pass) model = mtq.quantize(model, config, forward_loop=calibrate_loop) diff --git a/modelopt/torch/quantization/utils/core_utils.py b/modelopt/torch/quantization/utils/core_utils.py index 4340b8dc1f..9893d4e51a 100644 --- a/modelopt/torch/quantization/utils/core_utils.py +++ b/modelopt/torch/quantization/utils/core_utils.py @@ -710,6 +710,21 @@ def disable_calib(quantizer): quantizer._if_calib = original_if_calib +@contextmanager +def disabled_weight_quantizers(model: nn.Module): + """Disable weight quantizers during hessian collection.""" + disabled_modules = [] + for module in model.modules(): + if is_quantized_linear(module) and module.weight_quantizer.is_enabled: + module.weight_quantizer.disable() + disabled_modules.append(module) + try: + yield + finally: + for module in disabled_modules: + module.weight_quantizer.enable() + + @contextmanager def fsdp2_aware_weight_update(root_model, modules_to_update, reshard=True): """Context manager to update the FSDPParam list if an update is made to a submodule of an FSDPModule. diff --git a/modelopt/torch/utils/dataset_utils.py b/modelopt/torch/utils/dataset_utils.py index 29e0d8a882..cfd3e2c6bc 100644 --- a/modelopt/torch/utils/dataset_utils.py +++ b/modelopt/torch/utils/dataset_utils.py @@ -218,6 +218,7 @@ def get_dataset_samples( apply_chat_template: bool = False, tokenizer: "PreTrainedTokenizerBase | None" = None, split: str | list[str] | None = None, + skip_samples: int = 0, ) -> list[str]: """Load a portion of a dataset with the dataset name and a given size. @@ -240,6 +241,9 @@ def get_dataset_samples( split: Override the split(s) to load. Accepts a single split name or a list. 
If ``None``, uses the splits defined in ``SUPPORTED_DATASET_CONFIG`` for registered datasets, or ``["train"]`` for unregistered datasets. + skip_samples: Number of samples to skip from the beginning of each split + before collecting. The skip count is applied **per split**, so + ``skip_samples=128`` with 4 splits skips 128 from each split. Returns: Samples: The list of samples. @@ -306,12 +310,16 @@ def _preprocess(sample: dict) -> str: samples: list[str] = [] for dataset, n in zip(dataset_splits, num_per_split): + collected = 0 for i, sample in enumerate(dataset): - if i >= n: + if i < skip_samples: + continue + if collected >= n: break text = _preprocess(sample) if text: samples.append(text) + collected += 1 return samples @@ -340,6 +348,7 @@ def get_dataset_dataloader( device: torch.device | None = None, include_labels: bool = False, apply_chat_template: bool = False, + skip_samples: int = 0, ) -> DataLoader: """Get a dataloader with the dataset name and tokenizer of the target model. @@ -354,6 +363,8 @@ def get_dataset_dataloader( include_labels: Whether to include labels in the dataloader. apply_chat_template: Whether to apply the chat template to the samples (if supported by the dataset). + skip_samples: Number of samples to skip per split before collecting + (see :func:`get_dataset_samples`). Returns: An instance of dataloader. 
@@ -380,7 +391,11 @@ def get_dataset_dataloader( all_samples = [] for ds_name, num_sample in zip(dataset_name, num_samples): samples = get_dataset_samples( - ds_name, num_sample, apply_chat_template=apply_chat_template, tokenizer=tokenizer + ds_name, + num_sample, + apply_chat_template=apply_chat_template, + tokenizer=tokenizer, + skip_samples=skip_samples, ) all_samples.extend(samples) diff --git a/tests/gpu/torch/quantization/test_gptq.py b/tests/gpu/torch/quantization/test_gptq.py index 0c60bcd007..e229c89e5a 100644 --- a/tests/gpu/torch/quantization/test_gptq.py +++ b/tests/gpu/torch/quantization/test_gptq.py @@ -20,7 +20,9 @@ from transformers import AutoModelForCausalLM, AutoTokenizer import modelopt.torch.quantization as mtq -from modelopt.torch.quantization.model_calib import blockwise_weight_update, update_hessian +from modelopt.torch.export.unified_export_hf import _export_quantized_weight +from modelopt.torch.quantization.model_calib import gptq, update_hessian +from modelopt.torch.quantization.qtensor.nvfp4_tensor import NVFP4QTensor from modelopt.torch.utils.dataset_utils import create_forward_loop, get_dataset_dataloader RAND_SEED = 42 @@ -46,8 +48,11 @@ def test_update_hessian(): f"Expected hessian shape ({features}, {features}), got {updated_hessian.shape}" ) - # Verify sample count is updated correctly (incremented by batch_size) - assert new_n_samples == batch_size, f"Expected n_samples={batch_size}, got {new_n_samples}" + # Verify sample count is updated correctly (incremented by total tokens = batch * seq_len) + expected_n_samples = batch_size * seq_len + assert new_n_samples == expected_n_samples, ( + f"Expected n_samples={expected_n_samples}, got {new_n_samples}" + ) # Verify hessian is not all zeros after update assert not torch.allclose(updated_hessian, torch.zeros_like(updated_hessian)), ( @@ -70,22 +75,23 @@ def test_update_hessian(): # Manual calculation: # input_flat shape: (features, batch*seq) = (2, 12), all ones - # scaled_input = 
sqrt(2/6) * input_flat = sqrt(1/3) * ones(2, 12) - # outer_product = scaled_input @ scaled_input.t() = (2/6) * ones(2,12) @ ones(12,2) = [[4,4], [4,4]] - # Note: The scaling factor is (2/n_samples), so with n_samples=6 and 12 tokens: (2/6)*12 = 4 - expected_hessian = torch.ones(features, features, dtype=torch.float32) * 4.0 + # n_samples = batch * seq = 12 (token count after flattening) + # scaled_input = sqrt(2/12) * ones(2, 12) + # outer_product = (2/12) * ones(2,12) @ ones(12,2) = [[2,2], [2,2]] + expected_n_samples = batch_size * seq_len # 12 tokens + expected_hessian = torch.ones(features, features, dtype=torch.float32) * 2.0 assert torch.allclose(updated_hessian, expected_hessian, atol=1e-5), ( f"Expected hessian {expected_hessian}, got {updated_hessian}" ) - assert new_n_samples == batch_size + assert new_n_samples == expected_n_samples # Test 3: Accumulated hessians - verify equivalence # Processing [6,2,2] in one step should equal processing [2,2,2] three times seq_len = 2 features = 2 - # Process in 3 steps of batch_size=2 + # Process in 3 steps of batch_size=2 (4 tokens each, 12 total) hessian_accumulated = torch.zeros(features, features, dtype=torch.float32) n_samples_accumulated = 0 @@ -102,7 +108,8 @@ def test_update_hessian(): assert torch.allclose(hessian_accumulated, expected_hessian, atol=1e-5), ( f"Accumulated hessian should match expected: expected {expected_hessian}, got {hessian_accumulated}" ) - assert n_samples_accumulated == 6, f"Expected n_samples=6, got {n_samples_accumulated}" + # 3 batches * 2 batch_size * 2 seq_len = 12 tokens + assert n_samples_accumulated == 12, f"Expected n_samples=12, got {n_samples_accumulated}" @pytest.mark.parametrize( @@ -120,35 +127,20 @@ def test_update_hessian(): def test_gptq_updates(block_size, dim, model_weight, expect_weight_change): model = torch.nn.Linear(dim, dim).to("cuda") model.weight.data = model_weight - model.name = "linear" original_weight = model_weight.clone() - input = torch.randn(2, 16, 
dim).to("cuda") - hessian = torch.zeros(dim, dim).to("cpu") - n_samples = 0 + input_tensor = torch.randn(2, 16, dim).to("cuda") quant_cfg = mtq.NVFP4_DEFAULT_CFG - mtq.quantize(model, quant_cfg, forward_loop=lambda model: model(input)) + mtq.quantize(model, quant_cfg, forward_loop=lambda model: model(input_tensor)) # Get qdq weight q_dq_weight = model.weight_quantizer(model.weight.data) - # Restore original weight + # Restore original weight before GPTQ model.weight.data = original_weight.clone() - hessian, n_samples = update_hessian(input, hessian, n_samples) - - # Verify n_samples is update using hessian matrix - assert n_samples == input.shape[0], "n_samples should be equal to input.shape[0]" - - # Perform another forward pass to update hessian matrix - input_2 = torch.randn(3, 16, dim).to("cuda") - hessian, n_samples = update_hessian(input_2, hessian, n_samples) - assert n_samples == input.shape[0] + input_2.shape[0], ( - "n_samples should be equal to input.shape[0] + input_2.shape[0]" - ) - - hessian = hessian.to(input.device) - blockwise_weight_update(model, hessian, block_size, 0.1) + # Run GPTQ through the public API + gptq(model, forward_loop=lambda m: m(input_tensor), percdamp=0.1, block_size=block_size) if expect_weight_change: # Weight must change as GPTQ updates weights to adjust for quantization error assert not torch.allclose(model.weight.data, q_dq_weight), "Weight should not be equal" @@ -156,6 +148,69 @@ def test_gptq_updates(block_size, dim, model_weight, expect_weight_change): assert torch.allclose(model.weight.data, q_dq_weight), "Weight should be equal" +def test_gptq_export_roundtrip(): + """Test that GPTQ export + dequantize produces weights matching in-memory QDQ.""" + torch.manual_seed(RAND_SEED) + dim = 128 + block_size = 4 + + # Step 1: Create a simple linear model and quantize to install NVFP4 quantizers + model = torch.nn.Linear(dim, dim).to("cuda") + original_weight = model.weight.data.clone() + input_tensor = torch.randn(2, 16, 
dim).to("cuda") + quant_cfg = mtq.NVFP4_DEFAULT_CFG + + mtq.quantize(model, quant_cfg, forward_loop=lambda m: m(input_tensor)) + + # Restore original weight before GPTQ + model.weight.data = original_weight.clone() + + # Step 2: Perform GPTQ — compute Hessian and update weights + gptq(model, forward_loop=lambda m: m(input_tensor), percdamp=0.1, block_size=block_size) + + # Save the QDQ reference from the quantizer applied to GPTQ'd weights + gptq_weight_shape = model.weight.data.shape + gptq_weight_dtype = model.weight.data.dtype + qdq_ref = model.weight.data.clone() + + # Step 3: Export — converts weight to packed NVFP4 and registers scale buffers + _export_quantized_weight(model, torch.bfloat16) + + # Verify export produced the expected buffers + assert hasattr(model, "weight_scale"), "Export should register weight_scale buffer" + assert hasattr(model, "weight_scale_2"), "Export should register weight_scale_2 buffer" + + # Step 4: Dequantize the exported packed weight and compare with QDQ reference + packed_weight = model.weight.data + weight_scale = model.weight_scale + weight_scale_2 = model.weight_scale_2 + + nvfp4_qtensor = NVFP4QTensor(gptq_weight_shape, gptq_weight_dtype, packed_weight) + deq_weight = nvfp4_qtensor.dequantize( + dtype=torch.bfloat16, + scale=weight_scale, + double_scale=weight_scale_2, + block_sizes={-1: 16}, + ) + + assert deq_weight.shape == qdq_ref.shape, ( + f"Shape mismatch: dequantized {deq_weight.shape} vs QDQ ref {qdq_ref.shape}" + ) + diff = (deq_weight - qdq_ref.to(torch.bfloat16)).abs() + max_diff = diff.max().item() + max_diff_idx = diff.argmax().item() + max_diff_row = max_diff_idx // deq_weight.shape[1] + max_diff_col = max_diff_idx % deq_weight.shape[1] + num_mismatched = (diff > 1e-3).sum().item() + total_elements = diff.numel() + + assert torch.allclose(deq_weight, qdq_ref.to(torch.bfloat16), atol=1e-2), ( + f"Dequantized weight does not match QDQ reference. 
" + f"Max diff: {max_diff} at [{max_diff_row}, {max_diff_col}], " + f"mismatched (>1e-3): {num_mismatched}/{total_elements}" + ) + + @pytest.mark.parametrize( "quant_cfg", [mtq.NVFP4_DEFAULT_CFG, mtq.FP8_DEFAULT_CFG, mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG] ) @@ -179,7 +234,7 @@ def test_gptq_e2e_flow(quant_cfg): model.eval() quant_cfg = copy.deepcopy(quant_cfg) - quant_cfg["algorithm"] = "gptq_lite" + quant_cfg["algorithm"] = {"method": "gptq", "use_sequential": True} # Define quantizer/dataloader calib_dataloader = get_dataset_dataloader( dataset_name="cnn_dailymail", diff --git a/tests/gpu/torch/quantization/test_gptq_vq.py b/tests/gpu/torch/quantization/test_gptq_vq.py new file mode 100644 index 0000000000..8785c95c9d --- /dev/null +++ b/tests/gpu/torch/quantization/test_gptq_vq.py @@ -0,0 +1,96 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Test ModelOpt GPTQ with PSX LUTS VQ on a single expert linear layer.""" + +import copy + +import pytest +import torch + +RTN_CFG_NAME = ( + "PSX_LUTS_WEIGHT_VL_VS8_Entries65536_LFSR_max_sorted_bs16_ACTIVATION_NONE_CFG_routed_moes" +) +GPTQ_CFG_NAME = ( + "GPTQ_PSX_LUTS_WEIGHT_VL_VS8_Entries65536_LFSR_max_sorted_bs16_ACTIVATION_NONE_CFG_routed_moes" +) + + +def _configs_available(): + try: + import modelopt.torch.quantization as mtq + + return getattr(mtq, RTN_CFG_NAME, None) is not None + except Exception: + return False + + +class _SingleExpertModel(torch.nn.Module): + """Wraps a Linear so its path contains 'experts' to match the quant_cfg patterns.""" + + def __init__(self, in_features, out_features): + super().__init__() + self.experts = torch.nn.ModuleList([torch.nn.Linear(in_features, out_features, bias=False)]) + + def forward(self, x): + return self.experts[0](x) + + +@pytest.mark.skipif(not _configs_available(), reason="PSX LUTS plugin configs not available") +def test_modelopt_gptq_vs_rtn(): + """GPTQ should produce lower output NMSE than RTN on a single expert layer.""" + import modelopt.torch.quantization as mtq + + rtn_cfg = copy.deepcopy(getattr(mtq, RTN_CFG_NAME)) + gptq_cfg = copy.deepcopy(getattr(mtq, GPTQ_CFG_NAME)) + # Single-layer model has no decoder layers for sequential calibration + gptq_cfg["algorithm"]["use_sequential"] = False + + torch.manual_seed(42) + out_features, in_features, n_samples = 64, 256, 128 + + model = _SingleExpertModel(in_features, out_features).cuda().float() + orig_weight = model.experts[0].weight.data.clone() + calib_data = [torch.randn(1, in_features, device="cuda") for _ in range(n_samples)] + + def forward_loop(m): + for x in calib_data: + m(x) + + # RTN (fold weights so we get the actual QDQ'd values) + rtn_model = mtq.quantize(copy.deepcopy(model), rtn_cfg, forward_loop=forward_loop) + mtq.fold_weight(rtn_model) + rtn_weight = rtn_model.experts[0].weight.data.float() + + # GPTQ + gptq_model = 
mtq.quantize(copy.deepcopy(model), gptq_cfg, forward_loop=forward_loop) + gptq_weight = gptq_model.experts[0].weight.data.float() + + # Output NMSE + act = torch.cat(calib_data, dim=0).squeeze().T # (in_features, n_samples) + w = orig_weight.float() + ref_norm_sq = (w @ act).norm() ** 2 + nmse_rtn = ((rtn_weight - w) @ act).norm() ** 2 / ref_norm_sq + nmse_gptq = ((gptq_weight - w) @ act).norm() ** 2 / ref_norm_sq + + print(f"\nRTN NMSE: {nmse_rtn:.8f}") + print(f"GPTQ NMSE: {nmse_gptq:.8f}") + print(f"GPTQ gain over RTN: {nmse_rtn / nmse_gptq:.4f}x") + + assert nmse_gptq < nmse_rtn, "GPTQ should beat RTN" + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s"])