-
Notifications
You must be signed in to change notification settings - Fork 350
fix: handle accelerate CPU-offloaded models in FakeQuant export #1194
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -14,6 +14,7 @@ | |
| # limitations under the License. | ||
| """Export HuggingFace model to vLLM fakequant checkpoint.""" | ||
|
|
||
| import logging | ||
| from pathlib import Path | ||
|
|
||
| import torch | ||
|
|
@@ -26,6 +27,8 @@ | |
| from modelopt.torch.quantization.utils import get_quantizer_state_dict | ||
| from modelopt.torch.utils import get_unwrapped_name | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
| __all__ = ["export_hf_vllm_fq_checkpoint"] | ||
|
|
||
|
|
||
|
|
@@ -38,6 +41,105 @@ def disable_rotate(quantizer: TensorQuantizer): | |
| return False | ||
|
|
||
|
|
||
| def _materialize_offloaded_weights( | ||
| model: nn.Module, | ||
| state_dict: dict[str, torch.Tensor], | ||
| meta_keys: list[str], | ||
| ) -> None: | ||
| """Replace meta tensors in state_dict with actual data from accelerate offload hooks. | ||
|
|
||
| When a model is loaded with ``device_map="auto"`` and some layers are offloaded | ||
| to CPU or disk, ``model.state_dict()`` returns meta tensors (no data) for those | ||
| layers. This function walks the model's accelerate hooks to retrieve the actual | ||
| weight data and updates state_dict in-place. | ||
| """ | ||
| hook_map: dict[str, tuple] = {} | ||
| for name, module in model.named_modules(): | ||
| hook = getattr(module, "_hf_hook", None) | ||
| if hook is None: | ||
| continue | ||
| hooks = [hook] | ||
| if hasattr(hook, "hooks"): | ||
| hooks = hook.hooks | ||
| for h in hooks: | ||
| if hasattr(h, "weights_map") and h.weights_map is not None: | ||
| prefix = f"{name}." if name else "" | ||
| hook_map[prefix] = (module, h) | ||
| break | ||
|
|
||
| materialized = 0 | ||
| for key in meta_keys: | ||
| for prefix, (module, hook) in hook_map.items(): | ||
| if not key.startswith(prefix): | ||
| continue | ||
| local_key = key[len(prefix) :] | ||
| wmap = hook.weights_map | ||
| if hasattr(wmap, "dataset"): | ||
| lookup_key = wmap.prefix + local_key | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Potential prefix ordering issue: Consider sorting `hook_map` prefixes by length, longest first, before matching. The loop breaks on the first `key.startswith(prefix)` hit, so a shorter ancestor-module prefix inserted earlier can shadow the more specific child-module prefix that actually owns the weight. |
||
| actual_sd = wmap.dataset.state_dict | ||
| else: | ||
| lookup_key = local_key | ||
| actual_sd = wmap | ||
| if lookup_key in actual_sd: | ||
| state_dict[key] = actual_sd[lookup_key].detach().clone() | ||
| materialized += 1 | ||
| break | ||
| else: | ||
| logger.warning("Could not materialize meta tensor for key: %s", key) | ||
|
|
||
| logger.info("Materialized %d/%d offloaded weights to CPU", materialized, len(meta_keys)) | ||
|
|
||
|
|
||
| def _save_clean_checkpoint( | ||
| model: nn.Module, | ||
| clean_sd: dict[str, torch.Tensor], | ||
| export_dir: Path, | ||
| ) -> None: | ||
| """Save clean weights + config directly, bypassing model.save_pretrained(). | ||
|
|
||
| For accelerate-offloaded models, ``save_pretrained(state_dict=clean_sd)`` | ||
| ignores the provided state_dict and saves from internal state, leaking | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Behavioral change for ALL models: For the vLLM FakeQuant use case this is probably fine (vLLM doesn't need the extra artifacts that `save_pretrained` would emit, e.g. `generation_config.json` or custom modeling files — this path writes only the weight shards and `config.json`), but note that every model now takes this direct-save path, not just accelerate-offloaded ones. |
||
| quantizer keys. This function saves ``clean_sd`` directly via safetensors | ||
| API, guaranteeing only the intended keys are written. | ||
| """ | ||
| import json | ||
|
|
||
| from huggingface_hub import split_torch_state_dict_into_shards | ||
| from safetensors.torch import save_file | ||
|
|
||
| # Move to CPU and clone to break shared storage (tied weights like lm_head/embed_tokens). | ||
| # safetensors rejects tensors that share underlying storage. | ||
| cpu_sd = {k: v.cpu().clone() for k, v in clean_sd.items()} | ||
|
|
||
| state_dict_split = split_torch_state_dict_into_shards(cpu_sd, max_shard_size="5GB") | ||
| for shard_file, tensor_keys in state_dict_split.filename_to_tensors.items(): | ||
| shard = {k: cpu_sd[k] for k in tensor_keys} | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Missing `metadata={"format": "pt"}` in the `save_file` call — HF/vLLM loaders expect this metadata key when reading safetensors checkpoints. |
||
| save_file(shard, str(export_dir / shard_file)) | ||
coderabbitai[bot] marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| logger.info("Saved shard: %s (%d tensors)", shard_file, len(shard)) | ||
|
Comment on lines
+110
to
+118
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Shard before cloning tensors to CPU.
💡 Proposed fix- # Move to CPU and clone to break shared storage (tied weights like lm_head/embed_tokens).
- # safetensors rejects tensors that share underlying storage.
- cpu_sd = {k: v.cpu().clone() for k, v in clean_sd.items()}
-
- state_dict_split = split_torch_state_dict_into_shards(cpu_sd, max_shard_size="5GB")
+ state_dict_split = split_torch_state_dict_into_shards(clean_sd, max_shard_size="5GB")
for shard_file, tensor_keys in state_dict_split.filename_to_tensors.items():
- shard = {k: cpu_sd[k] for k in tensor_keys}
+ # Move only the current shard to CPU to keep peak memory bounded.
+ shard = {k: clean_sd[k].cpu().clone() for k in tensor_keys}
save_file(shard, str(export_dir / shard_file))
logger.info("Saved shard: %s (%d tensors)", shard_file, len(shard))
@@
logger.info(
"Checkpoint saved: %d weights in %d shard(s)",
- len(cpu_sd),
+ len(clean_sd),
len(state_dict_split.filename_to_tensors),
)Also applies to: 136-140 |
||
|
|
||
| if state_dict_split.is_sharded: | ||
| index = { | ||
| "metadata": state_dict_split.metadata, | ||
| "weight_map": state_dict_split.tensor_to_filename, | ||
| } | ||
| (export_dir / "model.safetensors.index.json").write_text(json.dumps(index, indent=2)) | ||
|
|
||
| if hasattr(model, "config"): | ||
| model.config.save_pretrained(export_dir) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Minor: import json
config_dict = model.config.to_dict()
config_dict.pop("auto_map", None)
(export_dir / "config.json").write_text(json.dumps(config_dict, indent=2))This avoids the read-modify-write cycle and the conditional. Though the current approach is functionally correct. |
||
| config_path = export_dir / "config.json" | ||
| if config_path.exists(): | ||
| config = json.loads(config_path.read_text()) | ||
| if config.pop("auto_map", None): | ||
| config_path.write_text(json.dumps(config, indent=2)) | ||
| logger.info("Saved config.json (auto_map stripped)") | ||
|
|
||
| logger.info( | ||
| "Checkpoint saved: %d weights in %d shard(s)", | ||
| len(cpu_sd), | ||
| len(state_dict_split.filename_to_tensors), | ||
| ) | ||
|
|
||
|
|
||
| def export_hf_vllm_fq_checkpoint( | ||
| model: nn.Module, | ||
| export_dir: Path | str, | ||
|
|
@@ -62,6 +164,18 @@ def export_hf_vllm_fq_checkpoint( | |
| # parameters are never modified. Apply each weight quantizer's fake-quant | ||
| # to the corresponding weight tensor in the copy. | ||
| state_dict = model.state_dict() | ||
|
|
||
| # Handle accelerate-offloaded models: state_dict() returns meta tensors | ||
| # for CPU/disk-offloaded layers. Materialize them from the offload hooks. | ||
| meta_keys = [k for k, v in state_dict.items() if v.is_meta] | ||
| if meta_keys: | ||
| logger.info( | ||
| "Found %d meta tensors in state_dict (accelerate offloading). " | ||
| "Materializing from offload hooks...", | ||
| len(meta_keys), | ||
| ) | ||
| _materialize_offloaded_weights(model, state_dict, meta_keys) | ||
|
|
||
|
Comment on lines
+167
to
+178
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Fail fast if any offloaded (non-quantizer) tensors remain on the meta device after materialization — otherwise the meta tensors propagate into the export and fail later with a confusing error.
💡 Proposed fix if meta_keys:
logger.info(
"Found %d meta tensors in state_dict (accelerate offloading). "
"Materializing from offload hooks...",
len(meta_keys),
)
_materialize_offloaded_weights(model, state_dict, meta_keys)
+ unresolved_meta_keys = [
+ k for k, v in state_dict.items() if v.is_meta and "quantizer" not in k
+ ]
+ if unresolved_meta_keys:
+ shown = ", ".join(unresolved_meta_keys[:10])
+ suffix = " ..." if len(unresolved_meta_keys) > 10 else ""
+ raise RuntimeError(f"Failed to materialize offloaded tensors: {shown}{suffix}")🤖 Prompt for AI Agents |
||
| fakequant_weights = set() | ||
| input_quantizers_folded_pqs = ( | ||
| set() | ||
|
|
@@ -86,6 +200,23 @@ def export_hf_vllm_fq_checkpoint( | |
| ) | ||
| if sd_key in state_dict: | ||
| w = state_dict[sd_key] | ||
| # Quantizer kernels (e.g., fp4_fake_quant_block) require CUDA. | ||
| # Offloaded weights materialized to CPU need a GPU hop. | ||
| if not w.is_cuda: | ||
| # Find a CUDA device: check quantizer buffers/params first, | ||
| # then fall back to sibling tensors on the parent module. | ||
| cuda_dev = None | ||
| for t in list(quantizer.parameters()) + list(quantizer.buffers()): | ||
| if t.is_cuda: | ||
| cuda_dev = t.device | ||
| break | ||
| if cuda_dev is None: | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Potential silent failure: If no CUDA device is found (e.g., all quantizer buffers/params and module params happen to be on CPU/meta), the weight silently stays on CPU and the CUDA-only quantizer kernel below will fail with an unrelated error. Consider raising explicitly: if cuda_dev is None:
raise RuntimeError(
f"Cannot find CUDA device for quantizer kernel on offloaded weight '{sd_key}'. "
"Ensure at least one quantizer buffer or module parameter is on CUDA."
) |
||
| for t in module.parameters(): | ||
| if t.is_cuda: | ||
| cuda_dev = t.device | ||
| break | ||
| if cuda_dev is not None: | ||
| w = w.to(cuda_dev) | ||
| w_quant = quantizer(w.float()).to(w.dtype).cpu() | ||
| # Fold pre_quant_scale: (x*s)@fake_quant(W) = x@(fake_quant(W)*s) | ||
| # Only valid when input_quantizer does NOT fake-quant activations. If it does | ||
|
|
@@ -161,8 +292,10 @@ def export_hf_vllm_fq_checkpoint( | |
| modelopt_state["modelopt_state_weights"] = quantizer_state_dict | ||
| torch.save(modelopt_state, export_dir / "vllm_fq_modelopt_state.pth") | ||
|
|
||
| # Step 3: Save HF weights using the pre-built folded state dict. | ||
| model.save_pretrained(export_dir, state_dict=clean_sd, save_modelopt_state=False) | ||
| # Step 3: Save HF weights directly from clean_sd. | ||
| # Bypass model.save_pretrained() because accelerate-offloaded models | ||
| # ignore the state_dict= argument, leaking quantizer keys into safetensors. | ||
| _save_clean_checkpoint(model, clean_sd, export_dir) | ||
|
|
||
| for wq, orig_rotate in wqs_to_restore: | ||
| wq.enable() | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code duplication: The hook traversal logic here (checking
`_hf_hook`, handling `SequentialHook`-like hooks with a `.hooks` attribute, and `PrefixedDataset` access via `.dataset.state_dict`) is largely duplicated from `modelopt/torch/quantization/plugins/accelerate.py` (`_get_cpu_offload_hook` and `weight_access_and_writeback_context`). Consider extracting a shared utility, or at minimum importing/reusing `_get_cpu_offload_hook` to find the relevant `AlignDevicesHook` for each module.