Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions examples/vllm_serve/fakequant_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,11 +134,11 @@ def determine_available_memory(self) -> int:
with disable_compilation(model):
return super().determine_available_memory()

def compile_or_warm_up_model(self) -> None:
def compile_or_warm_up_model(self) -> float:
if (
quant_config["quant_cfg"]
or quant_config["kv_quant_cfg"]
or quant_config["modelopt_state_path"]
):
_fakequant_run_prolog_worker(self)
super().compile_or_warm_up_model()
return super().compile_or_warm_up_model()
137 changes: 135 additions & 2 deletions modelopt/torch/export/plugins/vllm_fakequant_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# limitations under the License.
"""Export HuggingFace model to vLLM fakequant checkpoint."""

import logging
from pathlib import Path

import torch
Expand All @@ -26,6 +27,8 @@
from modelopt.torch.quantization.utils import get_quantizer_state_dict
from modelopt.torch.utils import get_unwrapped_name

logger = logging.getLogger(__name__)

__all__ = ["export_hf_vllm_fq_checkpoint"]


Expand All @@ -38,6 +41,105 @@ def disable_rotate(quantizer: TensorQuantizer):
return False


def _materialize_offloaded_weights(
model: nn.Module,
state_dict: dict[str, torch.Tensor],
meta_keys: list[str],
) -> None:
"""Replace meta tensors in state_dict with actual data from accelerate offload hooks.

When a model is loaded with ``device_map="auto"`` and some layers are offloaded
to CPU or disk, ``model.state_dict()`` returns meta tensors (no data) for those
layers. This function walks the model's accelerate hooks to retrieve the actual
weight data and updates state_dict in-place.
"""
hook_map: dict[str, tuple] = {}
for name, module in model.named_modules():
hook = getattr(module, "_hf_hook", None)
if hook is None:
continue
hooks = [hook]
if hasattr(hook, "hooks"):
hooks = hook.hooks
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code duplication: The hook traversal logic here (checking _hf_hook, handling SequentialHook-like hooks with .hooks attribute, and PrefixedDataset access via .dataset.state_dict) is largely duplicated from modelopt/torch/quantization/plugins/accelerate.py (_get_cpu_offload_hook and weight_access_and_writeback_context). Consider extracting a shared utility, or at minimum importing/reusing _get_cpu_offload_hook to find the relevant AlignDevicesHook for each module.

for h in hooks:
if hasattr(h, "weights_map") and h.weights_map is not None:
prefix = f"{name}." if name else ""
hook_map[prefix] = (module, h)
break

materialized = 0
for key in meta_keys:
for prefix, (module, hook) in hook_map.items():
if not key.startswith(prefix):
continue
local_key = key[len(prefix) :]
wmap = hook.weights_map
if hasattr(wmap, "dataset"):
lookup_key = wmap.prefix + local_key
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Potential prefix ordering issue: hook_map is a regular dict iterated by insertion order (module tree traversal order). For nested modules, a key like model.layers.0.self_attn.q_proj.weight could match prefix model. before it matches the more specific prefix model.layers.0.self_attn.q_proj.. The first match wins due to break, which could pick the wrong hook/weights_map.

Consider sorting hook_map by prefix length (descending) so longest (most specific) prefixes are checked first, or using a different lookup strategy.

actual_sd = wmap.dataset.state_dict
else:
lookup_key = local_key
actual_sd = wmap
if lookup_key in actual_sd:
state_dict[key] = actual_sd[lookup_key].detach().clone()
materialized += 1
break
else:
logger.warning("Could not materialize meta tensor for key: %s", key)

logger.info("Materialized %d/%d offloaded weights to CPU", materialized, len(meta_keys))


def _save_clean_checkpoint(
model: nn.Module,
clean_sd: dict[str, torch.Tensor],
export_dir: Path,
) -> None:
"""Save clean weights + config directly, bypassing model.save_pretrained().

For accelerate-offloaded models, ``save_pretrained(state_dict=clean_sd)``
ignores the provided state_dict and saves from internal state, leaking
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Behavioral change for ALL models: _save_clean_checkpoint now replaces model.save_pretrained() unconditionally — not just for offloaded models. The old save_pretrained() also saved generation_config.json, tokenizer files (if applicable), and ran any save hooks. The new code only saves safetensors + config.json.

For the vLLM FakeQuant use case this is probably fine (vLLM doesn't need generation_config.json from the export dir). But it's worth documenting this behavioral change, or alternatively only using _save_clean_checkpoint when offloading is detected (i.e., when meta_keys is non-empty) and falling back to model.save_pretrained() otherwise.

quantizer keys. This function saves ``clean_sd`` directly via safetensors
API, guaranteeing only the intended keys are written.
"""
import json

from huggingface_hub import split_torch_state_dict_into_shards
from safetensors.torch import save_file

# Move to CPU and clone to break shared storage (tied weights like lm_head/embed_tokens).
# safetensors rejects tensors that share underlying storage.
cpu_sd = {k: v.cpu().clone() for k, v in clean_sd.items()}

state_dict_split = split_torch_state_dict_into_shards(cpu_sd, max_shard_size="5GB")
for shard_file, tensor_keys in state_dict_split.filename_to_tensors.items():
shard = {k: cpu_sd[k] for k in tensor_keys}
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing export_dir.mkdir() call: The export_dir directory must exist before save_file() is called. The caller export_hf_vllm_fq_checkpoint does create it (export_dir.mkdir(parents=True, exist_ok=True)), but _save_clean_checkpoint as a standalone function doesn't ensure this. Consider adding a defensive export_dir.mkdir(parents=True, exist_ok=True) or documenting the precondition.

save_file(shard, str(export_dir / shard_file))
logger.info("Saved shard: %s (%d tensors)", shard_file, len(shard))
Comment on lines +110 to +118
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Shard before cloning tensors to CPU.

cpu_sd = {k: v.cpu().clone() ...} creates a second full copy of the checkpoint in host RAM before sharding. On the offload path, that doubles peak memory and can cause out-of-memory failures during the exact large-model exports this change is trying to unblock.

💡 Proposed fix
-    # Move to CPU and clone to break shared storage (tied weights like lm_head/embed_tokens).
-    # safetensors rejects tensors that share underlying storage.
-    cpu_sd = {k: v.cpu().clone() for k, v in clean_sd.items()}
-
-    state_dict_split = split_torch_state_dict_into_shards(cpu_sd, max_shard_size="5GB")
+    state_dict_split = split_torch_state_dict_into_shards(clean_sd, max_shard_size="5GB")
     for shard_file, tensor_keys in state_dict_split.filename_to_tensors.items():
-        shard = {k: cpu_sd[k] for k in tensor_keys}
+        # Move only the current shard to CPU to keep peak memory bounded.
+        shard = {k: clean_sd[k].cpu().clone() for k in tensor_keys}
         save_file(shard, str(export_dir / shard_file))
         logger.info("Saved shard: %s (%d tensors)", shard_file, len(shard))
@@
     logger.info(
         "Checkpoint saved: %d weights in %d shard(s)",
-        len(cpu_sd),
+        len(clean_sd),
         len(state_dict_split.filename_to_tensors),
     )

Also applies to: 136-140


if state_dict_split.is_sharded:
index = {
"metadata": state_dict_split.metadata,
"weight_map": state_dict_split.tensor_to_filename,
}
(export_dir / "model.safetensors.index.json").write_text(json.dumps(index, indent=2))

if hasattr(model, "config"):
model.config.save_pretrained(export_dir)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor: model.config.save_pretrained(export_dir) writes config.json then we immediately read it back, parse JSON, possibly modify, and write again. This is a small inefficiency. You could do:

import json
config_dict = model.config.to_dict()
config_dict.pop("auto_map", None)
(export_dir / "config.json").write_text(json.dumps(config_dict, indent=2))

This avoids the read-modify-write cycle and the conditional, though the current approach is functionally correct.

config_path = export_dir / "config.json"
if config_path.exists():
config = json.loads(config_path.read_text())
if config.pop("auto_map", None):
config_path.write_text(json.dumps(config, indent=2))
logger.info("Saved config.json (auto_map stripped)")

logger.info(
"Checkpoint saved: %d weights in %d shard(s)",
len(cpu_sd),
len(state_dict_split.filename_to_tensors),
)


def export_hf_vllm_fq_checkpoint(
model: nn.Module,
export_dir: Path | str,
Expand All @@ -62,6 +164,18 @@ def export_hf_vllm_fq_checkpoint(
# parameters are never modified. Apply each weight quantizer's fake-quant
# to the corresponding weight tensor in the copy.
state_dict = model.state_dict()

# Handle accelerate-offloaded models: state_dict() returns meta tensors
# for CPU/disk-offloaded layers. Materialize them from the offload hooks.
meta_keys = [k for k, v in state_dict.items() if v.is_meta]
if meta_keys:
logger.info(
"Found %d meta tensors in state_dict (accelerate offloading). "
"Materializing from offload hooks...",
len(meta_keys),
)
_materialize_offloaded_weights(model, state_dict, meta_keys)

Comment on lines +167 to +178
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Fail fast if any offloaded tensors stay on meta.

_materialize_offloaded_weights() only logs misses. If a non-quantizer key is still on the meta device at this point, the later fake-quant fold or _save_clean_checkpoint() will fail with a much less actionable error. Please re-check the state dict immediately after materialization and raise with the list of unresolved keys.

💡 Proposed fix
     if meta_keys:
         logger.info(
             "Found %d meta tensors in state_dict (accelerate offloading). "
             "Materializing from offload hooks...",
             len(meta_keys),
         )
         _materialize_offloaded_weights(model, state_dict, meta_keys)
+        unresolved_meta_keys = [
+            k for k, v in state_dict.items() if v.is_meta and "quantizer" not in k
+        ]
+        if unresolved_meta_keys:
+            shown = ", ".join(unresolved_meta_keys[:10])
+            suffix = " ..." if len(unresolved_meta_keys) > 10 else ""
+            raise RuntimeError(f"Failed to materialize offloaded tensors: {shown}{suffix}")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/export/plugins/vllm_fakequant_hf.py` around lines 167 - 178,
After calling _materialize_offloaded_weights(model, state_dict, meta_keys)
recompute unresolved_meta = [k for k,v in state_dict.items() if v.is_meta]; if
unresolved_meta is non-empty and contains any keys that are not
quantizer-related (e.g. not containing "quant" or "quantizer"), raise a
RuntimeError listing unresolved_meta and a short message mentioning that
materialization failed and will break subsequent fake-quant folding or
_save_clean_checkpoint; reference the symbols meta_keys,
_materialize_offloaded_weights, state_dict, and _save_clean_checkpoint so the
error helps locate the problem.

fakequant_weights = set()
input_quantizers_folded_pqs = (
set()
Expand All @@ -86,6 +200,23 @@ def export_hf_vllm_fq_checkpoint(
)
if sd_key in state_dict:
w = state_dict[sd_key]
# Quantizer kernels (e.g., fp4_fake_quant_block) require CUDA.
# Offloaded weights materialized to CPU need a GPU hop.
if not w.is_cuda:
# Find a CUDA device: check quantizer buffers/params first,
# then fall back to sibling tensors on the parent module.
cuda_dev = None
for t in list(quantizer.parameters()) + list(quantizer.buffers()):
if t.is_cuda:
cuda_dev = t.device
break
if cuda_dev is None:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Potential silent failure: If no CUDA device is found (e.g., all quantizer buffers/params and module params happen to be on CPU/meta), cuda_dev remains None and w stays on CPU. The subsequent quantizer(w.float()) call will likely fail with a cryptic CUDA error deep in the kernel. Consider raising a clear error:

if cuda_dev is None:
    raise RuntimeError(
        f"Cannot find CUDA device for quantizer kernel on offloaded weight '{sd_key}'. "
        "Ensure at least one quantizer buffer or module parameter is on CUDA."
    )

for t in module.parameters():
if t.is_cuda:
cuda_dev = t.device
break
if cuda_dev is not None:
w = w.to(cuda_dev)
w_quant = quantizer(w.float()).to(w.dtype).cpu()
# Fold pre_quant_scale: (x*s)@fake_quant(W) = x@(fake_quant(W)*s)
# Only valid when input_quantizer does NOT fake-quant activations. If it does
Expand Down Expand Up @@ -161,8 +292,10 @@ def export_hf_vllm_fq_checkpoint(
modelopt_state["modelopt_state_weights"] = quantizer_state_dict
torch.save(modelopt_state, export_dir / "vllm_fq_modelopt_state.pth")

# Step 3: Save HF weights using the pre-built folded state dict.
model.save_pretrained(export_dir, state_dict=clean_sd, save_modelopt_state=False)
# Step 3: Save HF weights directly from clean_sd.
# Bypass model.save_pretrained() because accelerate-offloaded models
# ignore the state_dict= argument, leaking quantizer keys into safetensors.
_save_clean_checkpoint(model, clean_sd, export_dir)

for wq, orig_rotate in wqs_to_restore:
wq.enable()
Expand Down
Loading