From 3895d217de90aecf5a46acdf25fc81783b8cc347 Mon Sep 17 00:00:00 2001 From: Sungsoo Ha Date: Tue, 7 Apr 2026 21:55:43 -0700 Subject: [PATCH 1/4] fix: handle accelerate CPU-offloaded models in FakeQuant export MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When models are loaded with device_map="auto" and layers are offloaded to CPU via AlignDevicesHook, model.state_dict() returns meta tensors and model.save_pretrained(state_dict=clean_sd) is ignored by accelerate. Four fixes: 1. _materialize_offloaded_weights(): resolve meta tensors from accelerate's AlignDevicesHook.weights_map before export. 2. GPU hop in weight processing: move CPU tensors to quantizer's device (quantizer kernels like fp4_fake_quant_block require CUDA). Uses quantizer buffers (amax) for device detection. 3. _save_clean_checkpoint(): bypass save_pretrained entirely, write safetensors directly via save_file() + split_torch_state_dict_into_shards(). Also strips auto_map from config.json (custom code files not in export). 4. FakeQuantWorker.compile_or_warm_up_model: return float (not None) to fix multiproc executor TypeError in max(compilation_times). Tested: Qwen3-0.6B on H100 with forced CPU offloading (500MiB GPU limit, 254/311 meta tensors). All checks passed — no quantizer keys in safetensors, no auto_map in config.json. 
Signed-off-by: Sungsoo Ha --- examples/vllm_serve/fakequant_worker.py | 4 +- .../torch/export/plugins/vllm_fakequant_hf.py | 126 +++++++++++++++++- 2 files changed, 126 insertions(+), 4 deletions(-) diff --git a/examples/vllm_serve/fakequant_worker.py b/examples/vllm_serve/fakequant_worker.py index ec2b1f4033..1fddecd6ae 100644 --- a/examples/vllm_serve/fakequant_worker.py +++ b/examples/vllm_serve/fakequant_worker.py @@ -134,11 +134,11 @@ def determine_available_memory(self) -> int: with disable_compilation(model): return super().determine_available_memory() - def compile_or_warm_up_model(self) -> None: + def compile_or_warm_up_model(self) -> float: if ( quant_config["quant_cfg"] or quant_config["kv_quant_cfg"] or quant_config["modelopt_state_path"] ): _fakequant_run_prolog_worker(self) - super().compile_or_warm_up_model() + return super().compile_or_warm_up_model() diff --git a/modelopt/torch/export/plugins/vllm_fakequant_hf.py b/modelopt/torch/export/plugins/vllm_fakequant_hf.py index 1908354a0a..296126ef23 100644 --- a/modelopt/torch/export/plugins/vllm_fakequant_hf.py +++ b/modelopt/torch/export/plugins/vllm_fakequant_hf.py @@ -14,11 +14,14 @@ # limitations under the License. """Export HuggingFace model to vLLM fakequant checkpoint.""" +import logging from pathlib import Path import torch import torch.nn as nn +logger = logging.getLogger(__name__) + import modelopt.torch.opt as mto from modelopt.torch.quantization.config import RotateConfig from modelopt.torch.quantization.conversion import quantizer_state @@ -38,6 +41,105 @@ def disable_rotate(quantizer: TensorQuantizer): return False +def _materialize_offloaded_weights( + model: nn.Module, + state_dict: dict[str, torch.Tensor], + meta_keys: list[str], +) -> None: + """Replace meta tensors in state_dict with actual data from accelerate offload hooks. 
+ + When a model is loaded with ``device_map="auto"`` and some layers are offloaded + to CPU or disk, ``model.state_dict()`` returns meta tensors (no data) for those + layers. This function walks the model's accelerate hooks to retrieve the actual + weight data and updates state_dict in-place. + """ + hook_map: dict[str, tuple] = {} + for name, module in model.named_modules(): + hook = getattr(module, "_hf_hook", None) + if hook is None: + continue + hooks = [hook] + if hasattr(hook, "hooks"): + hooks = hook.hooks + for h in hooks: + if hasattr(h, "weights_map") and h.weights_map is not None: + prefix = f"{name}." if name else "" + hook_map[prefix] = (module, h) + break + + materialized = 0 + for key in meta_keys: + for prefix, (module, hook) in hook_map.items(): + if not key.startswith(prefix): + continue + local_key = key[len(prefix):] + wmap = hook.weights_map + if hasattr(wmap, "dataset"): + lookup_key = wmap.prefix + local_key + actual_sd = wmap.dataset.state_dict + else: + lookup_key = local_key + actual_sd = wmap + if lookup_key in actual_sd: + state_dict[key] = actual_sd[lookup_key].detach().clone() + materialized += 1 + break + else: + logger.warning("Could not materialize meta tensor for key: %s", key) + + logger.info("Materialized %d/%d offloaded weights to CPU", materialized, len(meta_keys)) + + +def _save_clean_checkpoint( + model: nn.Module, + clean_sd: dict[str, torch.Tensor], + export_dir: Path, +) -> None: + """Save clean weights + config directly, bypassing model.save_pretrained(). + + For accelerate-offloaded models, ``save_pretrained(state_dict=clean_sd)`` + ignores the provided state_dict and saves from internal state, leaking + quantizer keys. This function saves ``clean_sd`` directly via safetensors + API, guaranteeing only the intended keys are written. 
+ """ + import json + + from huggingface_hub import split_torch_state_dict_into_shards + from safetensors.torch import save_file + + cpu_sd = {k: v.cpu() if v.device.type != "cpu" else v for k, v in clean_sd.items()} + + state_dict_split = split_torch_state_dict_into_shards(cpu_sd, max_shard_size="5GB") + for shard_file, tensor_keys in state_dict_split.filename_to_tensors.items(): + shard = {k: cpu_sd[k] for k in tensor_keys} + save_file(shard, str(export_dir / shard_file)) + logger.info("Saved shard: %s (%d tensors)", shard_file, len(shard)) + + if state_dict_split.is_sharded: + index = { + "metadata": state_dict_split.metadata, + "weight_map": state_dict_split.tensor_to_filename, + } + (export_dir / "model.safetensors.index.json").write_text( + json.dumps(index, indent=2) + ) + + if hasattr(model, "config"): + model.config.save_pretrained(export_dir) + config_path = export_dir / "config.json" + if config_path.exists(): + config = json.loads(config_path.read_text()) + if config.pop("auto_map", None): + config_path.write_text(json.dumps(config, indent=2)) + logger.info("Saved config.json (auto_map stripped)") + + logger.info( + "Checkpoint saved: %d weights in %d shard(s)", + len(cpu_sd), + len(state_dict_split.filename_to_tensors), + ) + + def export_hf_vllm_fq_checkpoint( model: nn.Module, export_dir: Path | str, @@ -62,6 +164,18 @@ def export_hf_vllm_fq_checkpoint( # parameters are never modified. Apply each weight quantizer's fake-quant # to the corresponding weight tensor in the copy. state_dict = model.state_dict() + + # Handle accelerate-offloaded models: state_dict() returns meta tensors + # for CPU/disk-offloaded layers. Materialize them from the offload hooks. + meta_keys = [k for k, v in state_dict.items() if v.is_meta] + if meta_keys: + logger.info( + "Found %d meta tensors in state_dict (accelerate offloading). 
" + "Materializing from offload hooks...", + len(meta_keys), + ) + _materialize_offloaded_weights(model, state_dict, meta_keys) + fakequant_weights = set() input_quantizers_folded_pqs = ( set() @@ -86,6 +200,12 @@ def export_hf_vllm_fq_checkpoint( ) if sd_key in state_dict: w = state_dict[sd_key] + # Quantizer kernels (e.g., fp4_fake_quant_block) require CUDA. + # Offloaded weights materialized to CPU need a GPU hop. + if not w.is_cuda: + qtensors = list(quantizer.parameters()) or list(quantizer.buffers()) + if qtensors and qtensors[0].is_cuda: + w = w.to(qtensors[0].device) w_quant = quantizer(w.float()).to(w.dtype).cpu() # Fold pre_quant_scale: (x*s)@fake_quant(W) = x@(fake_quant(W)*s) # Only valid when input_quantizer does NOT fake-quant activations. If it does @@ -161,8 +281,10 @@ def export_hf_vllm_fq_checkpoint( modelopt_state["modelopt_state_weights"] = quantizer_state_dict torch.save(modelopt_state, export_dir / "vllm_fq_modelopt_state.pth") - # Step 3: Save HF weights using the pre-built folded state dict. - model.save_pretrained(export_dir, state_dict=clean_sd, save_modelopt_state=False) + # Step 3: Save HF weights directly from clean_sd. + # Bypass model.save_pretrained() because accelerate-offloaded models + # ignore the state_dict= argument, leaking quantizer keys into safetensors. + _save_clean_checkpoint(model, clean_sd, export_dir) for wq, orig_rotate in wqs_to_restore: wq.enable() From b85c4e04191c7d0451b6dac79ea3ad6c503e23a0 Mon Sep 17 00:00:00 2001 From: Sungsoo Ha Date: Tue, 7 Apr 2026 22:38:43 -0700 Subject: [PATCH 2/4] fix: address code quality + CodeRabbit review 1. Fix ruff import ordering: move logger after all imports 2. Clone tensors in _save_clean_checkpoint to handle tied weights (safetensors rejects shared storage) 3. 
Robust GPU device fallback: check quantizer params/buffers, then parent module params (handles uninitialized quantizers) Signed-off-by: Sungsoo Ha --- .../torch/export/plugins/vllm_fakequant_hf.py | 25 ++++++++++++++----- 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/modelopt/torch/export/plugins/vllm_fakequant_hf.py b/modelopt/torch/export/plugins/vllm_fakequant_hf.py index 296126ef23..34f4bb72fb 100644 --- a/modelopt/torch/export/plugins/vllm_fakequant_hf.py +++ b/modelopt/torch/export/plugins/vllm_fakequant_hf.py @@ -20,8 +20,6 @@ import torch import torch.nn as nn -logger = logging.getLogger(__name__) - import modelopt.torch.opt as mto from modelopt.torch.quantization.config import RotateConfig from modelopt.torch.quantization.conversion import quantizer_state @@ -29,6 +27,8 @@ from modelopt.torch.quantization.utils import get_quantizer_state_dict from modelopt.torch.utils import get_unwrapped_name +logger = logging.getLogger(__name__) + __all__ = ["export_hf_vllm_fq_checkpoint"] @@ -107,7 +107,9 @@ def _save_clean_checkpoint( from huggingface_hub import split_torch_state_dict_into_shards from safetensors.torch import save_file - cpu_sd = {k: v.cpu() if v.device.type != "cpu" else v for k, v in clean_sd.items()} + # Move to CPU and clone to break shared storage (tied weights like lm_head/embed_tokens). + # safetensors rejects tensors that share underlying storage. + cpu_sd = {k: v.cpu().clone() for k, v in clean_sd.items()} state_dict_split = split_torch_state_dict_into_shards(cpu_sd, max_shard_size="5GB") for shard_file, tensor_keys in state_dict_split.filename_to_tensors.items(): @@ -203,9 +205,20 @@ def export_hf_vllm_fq_checkpoint( # Quantizer kernels (e.g., fp4_fake_quant_block) require CUDA. # Offloaded weights materialized to CPU need a GPU hop. 
if not w.is_cuda: - qtensors = list(quantizer.parameters()) or list(quantizer.buffers()) - if qtensors and qtensors[0].is_cuda: - w = w.to(qtensors[0].device) + # Find a CUDA device: check quantizer buffers/params first, + # then fall back to sibling tensors on the parent module. + cuda_dev = None + for t in list(quantizer.parameters()) + list(quantizer.buffers()): + if t.is_cuda: + cuda_dev = t.device + break + if cuda_dev is None: + for t in module.parameters(): + if t.is_cuda: + cuda_dev = t.device + break + if cuda_dev is not None: + w = w.to(cuda_dev) w_quant = quantizer(w.float()).to(w.dtype).cpu() # Fold pre_quant_scale: (x*s)@fake_quant(W) = x@(fake_quant(W)*s) # Only valid when input_quantizer does NOT fake-quant activations. If it does From 8c189ac135144723933b4e48c5ebf3210fdbaf00 Mon Sep 17 00:00:00 2001 From: Sungsoo Ha Date: Tue, 7 Apr 2026 22:45:13 -0700 Subject: [PATCH 3/4] style: apply ruff formatting (slice spacing, line wrapping) Signed-off-by: Sungsoo Ha --- modelopt/torch/export/plugins/vllm_fakequant_hf.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/modelopt/torch/export/plugins/vllm_fakequant_hf.py b/modelopt/torch/export/plugins/vllm_fakequant_hf.py index 34f4bb72fb..e087694336 100644 --- a/modelopt/torch/export/plugins/vllm_fakequant_hf.py +++ b/modelopt/torch/export/plugins/vllm_fakequant_hf.py @@ -72,7 +72,7 @@ def _materialize_offloaded_weights( for prefix, (module, hook) in hook_map.items(): if not key.startswith(prefix): continue - local_key = key[len(prefix):] + local_key = key[len(prefix) :] wmap = hook.weights_map if hasattr(wmap, "dataset"): lookup_key = wmap.prefix + local_key @@ -122,9 +122,7 @@ def _save_clean_checkpoint( "metadata": state_dict_split.metadata, "weight_map": state_dict_split.tensor_to_filename, } - (export_dir / "model.safetensors.index.json").write_text( - json.dumps(index, indent=2) - ) + (export_dir / "model.safetensors.index.json").write_text(json.dumps(index, indent=2)) if 
hasattr(model, "config"): model.config.save_pretrained(export_dir) From 46776ad3b1c43a08c4ca0781ecb2b1fb2d325d9e Mon Sep 17 00:00:00 2001 From: Sungsoo Ha Date: Tue, 14 Apr 2026 10:05:56 -0700 Subject: [PATCH 4/4] fix: harden offload fakequant export and add unit coverage Signed-off-by: Sungsoo Ha --- .../torch/export/plugins/vllm_fakequant_hf.py | 229 +++++++++++------- .../test_vllm_fakequant_hf_export_utils.py | 163 +++++++++++++ 2 files changed, 303 insertions(+), 89 deletions(-) create mode 100644 tests/unit/torch/export/test_vllm_fakequant_hf_export_utils.py diff --git a/modelopt/torch/export/plugins/vllm_fakequant_hf.py b/modelopt/torch/export/plugins/vllm_fakequant_hf.py index e087694336..2d6717ecbb 100644 --- a/modelopt/torch/export/plugins/vllm_fakequant_hf.py +++ b/modelopt/torch/export/plugins/vllm_fakequant_hf.py @@ -15,6 +15,7 @@ """Export HuggingFace model to vLLM fakequant checkpoint.""" import logging +from collections.abc import Mapping from pathlib import Path import torch @@ -53,33 +54,62 @@ def _materialize_offloaded_weights( layers. This function walks the model's accelerate hooks to retrieve the actual weight data and updates state_dict in-place. """ - hook_map: dict[str, tuple] = {} + hook_entries: list[tuple[str, str, Mapping[str, torch.Tensor]]] = [] + + def _weights_map_from_hook(hook_obj): + """Best-effort extraction of an accelerate weights_map from a hook object.""" + if hasattr(hook_obj, "weights_map") and hook_obj.weights_map is not None: + return hook_obj.weights_map + if hasattr(hook_obj, "hooks"): + for h in hook_obj.hooks: + if hasattr(h, "weights_map") and h.weights_map is not None: + return h.weights_map + return None + + try: + # Reuse accelerate plugin hook resolution instead of duplicating traversal logic. 
+ from modelopt.torch.quantization.plugins.accelerate import _get_cpu_offload_hook + except ImportError: + _get_cpu_offload_hook = None + for name, module in model.named_modules(): hook = getattr(module, "_hf_hook", None) - if hook is None: + if hook is None or _get_cpu_offload_hook is None: continue - hooks = [hook] - if hasattr(hook, "hooks"): - hooks = hook.hooks - for h in hooks: - if hasattr(h, "weights_map") and h.weights_map is not None: - prefix = f"{name}." if name else "" - hook_map[prefix] = (module, h) - break + + align_hook = None + if _get_cpu_offload_hook is not None: + try: + align_hook = _get_cpu_offload_hook(hook) + except AssertionError: + # Some accelerate hook variants do not expose a plain "weight" key. + # Fall back to generic weights_map extraction for export-time readout. + align_hook = None + + wmap = align_hook.weights_map if align_hook is not None else _weights_map_from_hook(hook) + if wmap is None: + continue + + if hasattr(wmap, "dataset"): + weight_prefix = wmap.prefix + actual_sd = wmap.dataset.state_dict + else: + weight_prefix = "" + actual_sd = wmap + + module_prefix = f"{name}." if name else "" + hook_entries.append((module_prefix, weight_prefix, actual_sd)) + + # Match most-specific module prefixes first to avoid ambiguous parent-prefix hits. 
+ hook_entries.sort(key=lambda x: len(x[0]), reverse=True) materialized = 0 for key in meta_keys: - for prefix, (module, hook) in hook_map.items(): - if not key.startswith(prefix): + for module_prefix, weight_prefix, actual_sd in hook_entries: + if not key.startswith(module_prefix): continue - local_key = key[len(prefix) :] - wmap = hook.weights_map - if hasattr(wmap, "dataset"): - lookup_key = wmap.prefix + local_key - actual_sd = wmap.dataset.state_dict - else: - lookup_key = local_key - actual_sd = wmap + local_key = key[len(module_prefix) :] + lookup_key = weight_prefix + local_key if lookup_key in actual_sd: state_dict[key] = actual_sd[lookup_key].detach().clone() materialized += 1 @@ -107,13 +137,13 @@ def _save_clean_checkpoint( from huggingface_hub import split_torch_state_dict_into_shards from safetensors.torch import save_file - # Move to CPU and clone to break shared storage (tied weights like lm_head/embed_tokens). - # safetensors rejects tensors that share underlying storage. - cpu_sd = {k: v.cpu().clone() for k, v in clean_sd.items()} + export_dir.mkdir(parents=True, exist_ok=True) - state_dict_split = split_torch_state_dict_into_shards(cpu_sd, max_shard_size="5GB") + state_dict_split = split_torch_state_dict_into_shards(clean_sd, max_shard_size="5GB") for shard_file, tensor_keys in state_dict_split.filename_to_tensors.items(): - shard = {k: cpu_sd[k] for k in tensor_keys} + # Keep peak memory bounded: move and clone one shard at a time. + # Cloning also breaks shared storage, which safetensors rejects. 
+ shard = {k: clean_sd[k].cpu().clone() for k in tensor_keys} save_file(shard, str(export_dir / shard_file)) logger.info("Saved shard: %s (%d tensors)", shard_file, len(shard)) @@ -125,17 +155,18 @@ def _save_clean_checkpoint( (export_dir / "model.safetensors.index.json").write_text(json.dumps(index, indent=2)) if hasattr(model, "config"): - model.config.save_pretrained(export_dir) + config = model.config.to_dict() config_path = export_dir / "config.json" - if config_path.exists(): - config = json.loads(config_path.read_text()) - if config.pop("auto_map", None): - config_path.write_text(json.dumps(config, indent=2)) - logger.info("Saved config.json (auto_map stripped)") + config_path.write_text(json.dumps(config, indent=2) + "\n") + logger.info("Saved config.json") + + generation_config = getattr(model, "generation_config", None) + if generation_config is not None: + generation_config.save_pretrained(export_dir) logger.info( "Checkpoint saved: %d weights in %d shard(s)", - len(cpu_sd), + len(clean_sd), len(state_dict_split.filename_to_tensors), ) @@ -168,6 +199,7 @@ def export_hf_vllm_fq_checkpoint( # Handle accelerate-offloaded models: state_dict() returns meta tensors # for CPU/disk-offloaded layers. Materialize them from the offload hooks. meta_keys = [k for k, v in state_dict.items() if v.is_meta] + has_offloaded_weights = bool(meta_keys) if meta_keys: logger.info( "Found %d meta tensors in state_dict (accelerate offloading). " @@ -175,6 +207,18 @@ def export_hf_vllm_fq_checkpoint( len(meta_keys), ) _materialize_offloaded_weights(model, state_dict, meta_keys) + unresolved_meta_keys = [ + k + for k, v in state_dict.items() + if v.is_meta and "quantizer" not in k and "quant" not in k + ] + if unresolved_meta_keys: + shown = ", ".join(unresolved_meta_keys[:10]) + suffix = " ..." 
if len(unresolved_meta_keys) > 10 else "" + raise RuntimeError( + "Failed to materialize offloaded tensors before fake-quant folding / " + f"_save_clean_checkpoint: {shown}{suffix}" + ) fakequant_weights = set() input_quantizers_folded_pqs = ( @@ -203,20 +247,23 @@ def export_hf_vllm_fq_checkpoint( # Quantizer kernels (e.g., fp4_fake_quant_block) require CUDA. # Offloaded weights materialized to CPU need a GPU hop. if not w.is_cuda: - # Find a CUDA device: check quantizer buffers/params first, - # then fall back to sibling tensors on the parent module. + # Find a CUDA device from quantizer/module tensors. cuda_dev = None for t in list(quantizer.parameters()) + list(quantizer.buffers()): if t.is_cuda: cuda_dev = t.device break if cuda_dev is None: - for t in module.parameters(): + for t in list(module.parameters()) + list(module.buffers()): if t.is_cuda: cuda_dev = t.device break - if cuda_dev is not None: - w = w.to(cuda_dev) + if cuda_dev is None: + raise RuntimeError( + "Cannot find CUDA device for quantizer kernel on offloaded weight " + f"'{sd_key}'. Ensure at least one quantizer/module tensor is on CUDA." + ) + w = w.to(cuda_dev) w_quant = quantizer(w.float()).to(w.dtype).cpu() # Fold pre_quant_scale: (x*s)@fake_quant(W) = x@(fake_quant(W)*s) # Only valid when input_quantizer does NOT fake-quant activations. If it does @@ -248,55 +295,59 @@ def export_hf_vllm_fq_checkpoint( # Rotation is also cleared: the weight was already folded with rotation applied, # so if fold_weight is called on reload it must not re-rotate the exported weight. 
wqs_to_restore = [] - for _, module in model.named_modules(): - if isinstance(module, QuantModule): - for attr_name, quantizer in module.named_children(): - if ( - attr_name.endswith("weight_quantizer") - and isinstance(quantizer, TensorQuantizer) - and quantizer.is_enabled - ): - quantizer.disable() - orig_rotate = quantizer._rotate - if quantizer.rotate_is_enabled: - quantizer._rotate = disable_rotate(quantizer) - wqs_to_restore.append((quantizer, orig_rotate)) - - quantizer_state_dict = get_quantizer_state_dict(model) - for key in list(quantizer_state_dict): - if key.endswith("weight_quantizer"): - # Fakequant amax is folded into HF weights; do not reload weight quantizer tensors. - quantizer_state_dict.pop(key) - elif key in input_quantizers_folded_pqs: - # pre_quant_scale was folded into the weight; keep the buffer for strict load but - # save identity so activations are not scaled twice. - qstate_val = quantizer_state_dict[key] - if isinstance(qstate_val, dict) and "_pre_quant_scale" in qstate_val: - quantizer_state_dict[key]["_pre_quant_scale"] = torch.ones_like( - qstate_val["_pre_quant_scale"] - ) - modelopt_state = mto.modelopt_state(model) - # ``modelopt_state`` may be stale if another mode (e.g. calibrate) ran last. Rebuild - # ``quantizer_state`` and drop disabled weight quantizer entries (weights already folded). - qstate = quantizer_state(model) - for key in list(qstate): - if key.endswith("weight_quantizer") and qstate[key].get("_disabled"): - qstate.pop(key) - - for mode_str, m_state in modelopt_state.get("modelopt_state_dict", []): - if mode_str == "quantize" and "metadata" in m_state: - m_state["metadata"]["quantizer_state"] = qstate - break - - # Per-quantizer tensor dict loaded alongside metadata on reload. - modelopt_state["modelopt_state_weights"] = quantizer_state_dict - torch.save(modelopt_state, export_dir / "vllm_fq_modelopt_state.pth") - - # Step 3: Save HF weights directly from clean_sd. 
- # Bypass model.save_pretrained() because accelerate-offloaded models - # ignore the state_dict= argument, leaking quantizer keys into safetensors. - _save_clean_checkpoint(model, clean_sd, export_dir) - - for wq, orig_rotate in wqs_to_restore: - wq.enable() - wq._rotate = orig_rotate + try: + for _, module in model.named_modules(): + if isinstance(module, QuantModule): + for attr_name, quantizer in module.named_children(): + if ( + attr_name.endswith("weight_quantizer") + and isinstance(quantizer, TensorQuantizer) + and quantizer.is_enabled + ): + quantizer.disable() + orig_rotate = quantizer._rotate + if quantizer.rotate_is_enabled: + quantizer._rotate = disable_rotate(quantizer) + wqs_to_restore.append((quantizer, orig_rotate)) + + quantizer_state_dict = get_quantizer_state_dict(model) + for key in list(quantizer_state_dict): + if key.endswith("weight_quantizer"): + # Fakequant amax is folded into HF weights; do not reload weight quantizer tensors. + quantizer_state_dict.pop(key) + elif key in input_quantizers_folded_pqs: + # pre_quant_scale was folded into the weight; keep the buffer for strict load but + # save identity so activations are not scaled twice. + qstate_val = quantizer_state_dict[key] + if isinstance(qstate_val, dict) and "_pre_quant_scale" in qstate_val: + quantizer_state_dict[key]["_pre_quant_scale"] = torch.ones_like( + qstate_val["_pre_quant_scale"] + ) + modelopt_state = mto.modelopt_state(model) + # ``modelopt_state`` may be stale if another mode (e.g. calibrate) ran last. Rebuild + # ``quantizer_state`` and drop disabled weight quantizer entries (weights already folded). 
+ qstate = quantizer_state(model) + for key in list(qstate): + if key.endswith("weight_quantizer") and qstate[key].get("_disabled"): + qstate.pop(key) + + for mode_str, m_state in modelopt_state.get("modelopt_state_dict", []): + if mode_str == "quantize" and "metadata" in m_state: + m_state["metadata"]["quantizer_state"] = qstate + break + + # Per-quantizer tensor dict loaded alongside metadata on reload. + modelopt_state["modelopt_state_weights"] = quantizer_state_dict + torch.save(modelopt_state, export_dir / "vllm_fq_modelopt_state.pth") + + # Step 3: Save HF weights. + # Accelerate-offloaded models may ignore state_dict= in save_pretrained() + # and leak quantizer keys, so use manual save only in that case. + if has_offloaded_weights: + _save_clean_checkpoint(model, clean_sd, export_dir) + else: + model.save_pretrained(export_dir, state_dict=clean_sd, save_modelopt_state=False) + finally: + for wq, orig_rotate in wqs_to_restore: + wq.enable() + wq._rotate = orig_rotate diff --git a/tests/unit/torch/export/test_vllm_fakequant_hf_export_utils.py b/tests/unit/torch/export/test_vllm_fakequant_hf_export_utils.py new file mode 100644 index 0000000000..2e2604e412 --- /dev/null +++ b/tests/unit/torch/export/test_vllm_fakequant_hf_export_utils.py @@ -0,0 +1,163 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import sys +import types + +import pytest +import torch +import torch.nn as nn +from modelopt.torch.export.plugins import vllm_fakequant_hf as vllm_fq + + +class _MinimalModel(nn.Module): + def __init__(self, meta_weight: bool = False): + super().__init__() + device = "meta" if meta_weight else "cpu" + self.weight = nn.Parameter(torch.ones(4, device=device)) + self.save_calls = [] + + def save_pretrained(self, export_dir, **kwargs): + self.save_calls.append((export_dir, kwargs)) + + +class _DummyHook: + def __init__(self, weights_map): + self.weights_map = weights_map + + +def _patch_minimal_modelopt_state(monkeypatch): + monkeypatch.setattr(vllm_fq, "get_quantizer_state_dict", lambda _model: {}) + monkeypatch.setattr(vllm_fq, "quantizer_state", lambda _model: {}) + monkeypatch.setattr(vllm_fq.mto, "modelopt_state", lambda _model: {"modelopt_state_dict": []}) + monkeypatch.setattr(vllm_fq.torch, "save", lambda _obj, _path: None) + + +def test_materialize_uses_longest_module_prefix(monkeypatch): + class _NestedModel(nn.Module): + def __init__(self): + super().__init__() + self.a = nn.Module() + self.a.b = nn.Module() + self.a._hf_hook = _DummyHook({"b.weight": torch.tensor([1.0])}) + self.a.b._hf_hook = _DummyHook({"weight": torch.tensor([2.0])}) + + model = _NestedModel() + state_dict = {"a.b.weight": torch.empty(1, device="meta")} + + fake_accel = types.ModuleType("modelopt.torch.quantization.plugins.accelerate") + fake_accel._get_cpu_offload_hook = lambda hook: hook + monkeypatch.setitem(sys.modules, "modelopt.torch.quantization.plugins.accelerate", fake_accel) + + vllm_fq._materialize_offloaded_weights(model, state_dict, ["a.b.weight"]) + assert torch.allclose(state_dict["a.b.weight"], torch.tensor([2.0])) + + +def test_export_raises_if_non_quant_meta_tensors_remain(monkeypatch, tmp_path): + _patch_minimal_modelopt_state(monkeypatch) + model = _MinimalModel(meta_weight=True) + + monkeypatch.setattr(vllm_fq, "_materialize_offloaded_weights", lambda *_args, 
**_kwargs: None) + + with ( + torch.inference_mode(), + pytest.raises(RuntimeError, match="Failed to materialize offloaded tensors") as exc, + ): + vllm_fq.export_hf_vllm_fq_checkpoint(model, export_dir=tmp_path / "export_meta_fail") + assert "_save_clean_checkpoint" in str(exc.value) + + +def test_export_uses_model_save_pretrained_when_not_offloaded(monkeypatch, tmp_path): + _patch_minimal_modelopt_state(monkeypatch) + model = _MinimalModel(meta_weight=False) + called = {"clean": 0} + + def _save_clean_checkpoint(*_args, **_kwargs): + called["clean"] += 1 + + monkeypatch.setattr(vllm_fq, "_save_clean_checkpoint", _save_clean_checkpoint) + vllm_fq.export_hf_vllm_fq_checkpoint(model, export_dir=tmp_path / "export_non_offloaded") + + assert called["clean"] == 0 + assert len(model.save_calls) == 1 + assert model.save_calls[0][1]["save_modelopt_state"] is False + assert "state_dict" in model.save_calls[0][1] + + +def test_export_uses_clean_checkpoint_when_offloaded(monkeypatch, tmp_path): + _patch_minimal_modelopt_state(monkeypatch) + model = _MinimalModel(meta_weight=True) + called = {"clean": 0} + + def _materialize(_model, state_dict, _meta_keys): + state_dict["weight"] = torch.ones(4) + + def _save_clean_checkpoint(*_args, **_kwargs): + called["clean"] += 1 + + def _unexpected_save_pretrained(*_args, **_kwargs): + raise AssertionError("model.save_pretrained should not be called for offloaded export") + + monkeypatch.setattr(vllm_fq, "_materialize_offloaded_weights", _materialize) + monkeypatch.setattr(vllm_fq, "_save_clean_checkpoint", _save_clean_checkpoint) + model.save_pretrained = _unexpected_save_pretrained + + vllm_fq.export_hf_vllm_fq_checkpoint(model, export_dir=tmp_path / "export_offloaded") + assert called["clean"] == 1 + + +def test_export_raises_when_cuda_device_cannot_be_found(monkeypatch, tmp_path): + _patch_minimal_modelopt_state(monkeypatch) + + class _DummyTensorQuantizer(nn.Module): + def __init__(self): + super().__init__() + self.fake_quant = 
True + self.is_enabled = True + self.rotate_is_enabled = False + self._rotate = False + + def disable(self): + self.is_enabled = False + + def enable(self): + self.is_enabled = True + + def forward(self, x): + return x + + class _DummyQuantModule(nn.Module): + def __init__(self): + super().__init__() + self.weight = nn.Parameter(torch.ones(2, 2)) + self.weight_quantizer = _DummyTensorQuantizer() + + class _DummyModel(nn.Module): + def __init__(self): + super().__init__() + self.block = _DummyQuantModule() + + def save_pretrained(self, _export_dir, **_kwargs): + return None + + monkeypatch.setattr(vllm_fq, "QuantModule", _DummyQuantModule) + monkeypatch.setattr(vllm_fq, "TensorQuantizer", _DummyTensorQuantizer) + + with torch.inference_mode(), pytest.raises( + RuntimeError, match="Cannot find CUDA device for quantizer kernel" + ): + vllm_fq.export_hf_vllm_fq_checkpoint( + _DummyModel(), export_dir=tmp_path / "export_cuda_missing" + )