diff --git a/examples/deepseek/quantize_to_nvfp4.py b/examples/deepseek/quantize_to_nvfp4.py index d94f48fce..a18cbbc16 100644 --- a/examples/deepseek/quantize_to_nvfp4.py +++ b/examples/deepseek/quantize_to_nvfp4.py @@ -82,6 +82,20 @@ def _remap_key(key_dict: dict[str, Any]): key_dict.update(new_dict) +def remove_quantization_config_from_original_config(export_dir: str) -> None: + """Remove `quantization_config` from exported HF `config.json`. + + Assumes the exported checkpoint directory has a `config.json` containing `quantization_config`. + """ + config_path = os.path.join(export_dir, "config.json") + with open(config_path) as f: + cfg = json.load(f) + del cfg["quantization_config"] + with open(config_path, "w") as f: + json.dump(cfg, f, indent=2, sort_keys=True) + f.write("\n") + + def load_and_preprocess_state_dict(modelopt_state_root, world_size=8): state_dict_list = [ torch.load(f"{modelopt_state_root}/amax_dict_rank{rank}-mp{world_size}.pt") @@ -302,3 +316,5 @@ def get_tensor(tensor_name): save_root=args.fp4_path, per_layer_quant_config=per_layer_quant_config, ) + + remove_quantization_config_from_original_config(args.fp4_path)