From 55eaa6efb28c4cee90643b95435b43e432250058 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 22 Jan 2026 16:44:06 +0530 Subject: [PATCH 1/3] style --- src/diffusers/models/modeling_utils.py | 20 ++++++++++---------- src/diffusers/pipelines/pipeline_utils.py | 4 ++-- src/diffusers/utils/flashpack_utils.py | 20 +++++++++----------- src/diffusers/utils/import_utils.py | 5 +++++ 4 files changed, 26 insertions(+), 23 deletions(-) diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 5b0f8a3a0d64..c89204dff74b 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -709,9 +709,8 @@ def save_pretrained( repository you want to push to with `repo_id` (will default to the name of `save_directory` in your namespace). use_flashpack (`bool`, *optional*, defaults to `False`): - Whether to save the model in [FlashPack](https://github.com/fal-ai/flashpack) format. - FlashPack is a binary format that allows for faster loading. - Requires the `flashpack` library to be installed. + Whether to save the model in [FlashPack](https://github.com/fal-ai/flashpack) format. FlashPack is a + binary format that allows for faster loading. Requires the `flashpack` library to be installed. kwargs (`Dict[str, Any]`, *optional*): Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. """ @@ -743,6 +742,7 @@ def save_pretrained( if not is_main_process: return from ..utils.flashpack_utils import save_flashpack + save_flashpack( self, save_directory, @@ -953,9 +953,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P `safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors` weights. If set to `False`, `safetensors` weights are not loaded. 
use_flashpack (`bool`, *optional*, defaults to `False`): - If set to `True`, the model is first loaded from `flashpack` (https://github.com/fal-ai/flashpack) weights if a compatible `.flashpack` file - is found. If flashpack is unavailable or the `.flashpack` file cannot be used, automatic fallback to - the standard loading path (for example, `safetensors`). + If set to `True`, the model is first loaded from `flashpack` (https://github.com/fal-ai/flashpack) + weights if a compatible `.flashpack` file is found. If flashpack is unavailable or the `.flashpack` + file cannot be used, automatic fallback to the standard loading path (for example, `safetensors`). disable_mmap ('bool', *optional*, defaults to 'False'): Whether to disable mmap when loading a Safetensors model. This option can perform better when the model is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well. @@ -1279,7 +1279,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P logger.warning( "Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead." ) - + if resolved_model_file is None and not is_sharded: resolved_model_file = _get_model_file( pretrained_model_name_or_path, @@ -1323,12 +1323,13 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P if flashpack_file is not None: from ..utils.flashpack_utils import load_flashpack + # Even when using FlashPack, we preserve `low_cpu_mem_usage` behavior by initializing # the model with meta tensors. Since FlashPack cannot write into meta tensors, we # explicitly materialize parameters before loading to ensure correctness and parity # with the standard loading path. 
if any(p.device.type == "meta" for p in model.parameters()): - model.to_empty(device="cpu") + model.to_empty(device="cpu") load_flashpack(model, flashpack_file) model.register_to_config(_name_or_path=pretrained_model_name_or_path) model.eval() @@ -1434,12 +1435,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P if output_loading_info: return model, loading_info - + logger.warning(f"Model till end {pretrained_model_name_or_path} loaded successfully") return model - # Adapted from `transformers`. @wraps(torch.nn.Module.cuda) def cuda(self, *args, **kwargs): diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 34e42f42862f..21522216061f 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -270,8 +270,8 @@ class implements both a save and loading method. The pipeline is easily reloaded repository you want to push to with `repo_id` (will default to the name of `save_directory` in your namespace). use_flashpack (`bool`, *optional*, defaults to `False`): - Whether or not to use `flashpack` to save the model weights. Requires the `flashpack` library: `pip install - flashpack`. + Whether or not to use `flashpack` to save the model weights. Requires the `flashpack` library: `pip + install flashpack`. kwargs (`Dict[str, Any]`, *optional*): Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. 
""" diff --git a/src/diffusers/utils/flashpack_utils.py b/src/diffusers/utils/flashpack_utils.py index 14031a7c543a..821fe5e7fd4e 100644 --- a/src/diffusers/utils/flashpack_utils.py +++ b/src/diffusers/utils/flashpack_utils.py @@ -1,12 +1,15 @@ import json import os from typing import Optional + +from ..utils import _add_variant from .import_utils import is_flashpack_available from .logging import get_logger -from ..utils import _add_variant + logger = get_logger(__name__) + def save_flashpack( model, save_directory: str, @@ -54,30 +57,25 @@ def save_flashpack( json.dump(config_data, f, indent=4) except Exception as config_err: - logger.warning( - f"FlashPack weights saved, but config serialization failed: {config_err}" - ) + logger.warning(f"FlashPack weights saved, but config serialization failed: {config_err}") except Exception as e: logger.error(f"Failed to save weights in FlashPack format: {e}") raise + def load_flashpack(model, flashpack_file: str): """ Assign FlashPack weights from a file into an initialized PyTorch model. """ if not is_flashpack_available(): - raise ImportError( - "FlashPack weights require the `flashpack` package. " - "Install with `pip install flashpack`." - ) + raise ImportError("FlashPack weights require the `flashpack` package. 
Install with `pip install flashpack`.") from flashpack import assign_from_file + logger.warning(f"Loading FlashPack weights from {flashpack_file}") try: assign_from_file(model, flashpack_file) except Exception as e: - raise RuntimeError( - f"Failed to load FlashPack weights from {flashpack_file}" - ) from e \ No newline at end of file + raise RuntimeError(f"Failed to load FlashPack weights from {flashpack_file}") from e diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index 2b99e42a26f7..af6df925d72e 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -233,6 +233,7 @@ def _is_package_available(pkg_name: str, get_dist_name: bool = False) -> Tuple[b _av_available, _av_version = _is_package_available("av") _flashpack_available, _flashpack_version = _is_package_available("flashpack") + def is_torch_available(): return _torch_available @@ -424,9 +425,11 @@ def is_kornia_available(): def is_av_available(): return _av_available + def is_flashpack_available(): return _flashpack_available + # docstyle-ignore FLAX_IMPORT_ERROR = """ {0} requires the FLAX library but it was not found in your environment. 
Checkout the instructions on the @@ -943,6 +946,7 @@ def is_aiter_version(operation: str, version: str): return False return compare_versions(parse(_aiter_version), operation, version) + @cache def is_flashpack_version(operation: str, version: str): """ @@ -952,6 +956,7 @@ def is_flashpack_version(operation: str, version: str): return False return compare_versions(parse(_flashpack_version), operation, version) + def get_objects_from_module(module): """ Returns a dict of object names and values in a module, while skipping private/internal objects From 668f265054d8f5b74f6c83e9e5fb5bd06194014e Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 22 Jan 2026 17:02:52 +0530 Subject: [PATCH 2/3] up --- src/diffusers/models/modeling_utils.py | 132 ++++++++++++------------- 1 file changed, 65 insertions(+), 67 deletions(-) diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index c89204dff74b..dc98af07b7e2 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -731,23 +731,8 @@ def save_pretrained( " the logger on the traceback to understand the reason why the quantized model is not serializable." 
) - weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME - weights_name = _add_variant(weights_name, variant) - weights_name_pattern = weights_name.replace(".bin", "{suffix}.bin").replace( - ".safetensors", "{suffix}.safetensors" - ) - os.makedirs(save_directory, exist_ok=True) - if use_flashpack: - if not is_main_process: - return - from ..utils.flashpack_utils import save_flashpack - save_flashpack( - self, - save_directory, - variant=variant, - ) if push_to_hub: commit_message = kwargs.pop("commit_message", None) private = kwargs.pop("private", None) @@ -759,67 +744,80 @@ def save_pretrained( # Only save the model itself if we are using distributed training model_to_save = self - # Attach architecture to the config # Save the config if is_main_process: model_to_save.save_config(save_directory) - # Save the model - state_dict = model_to_save.state_dict() + if use_flashpack: + if not is_main_process: + return - # Save the model - state_dict_split = split_torch_state_dict_into_shards( - state_dict, max_shard_size=max_shard_size, filename_pattern=weights_name_pattern - ) + from ..utils.flashpack_utils import save_flashpack - # Clean the folder from a previous save - if is_main_process: - for filename in os.listdir(save_directory): - if filename in state_dict_split.filename_to_tensors.keys(): - continue - full_filename = os.path.join(save_directory, filename) - if not os.path.isfile(full_filename): - continue - weights_without_ext = weights_name_pattern.replace(".bin", "").replace(".safetensors", "") - weights_without_ext = weights_without_ext.replace("{suffix}", "") - filename_without_ext = filename.replace(".bin", "").replace(".safetensors", "") - # make sure that file to be deleted matches format of sharded file, e.g. 
pytorch_model-00001-of-00005 - if ( - filename.startswith(weights_without_ext) - and _REGEX_SHARD.fullmatch(filename_without_ext) is not None - ): - os.remove(full_filename) - - for filename, tensors in state_dict_split.filename_to_tensors.items(): - shard = {tensor: state_dict[tensor].contiguous() for tensor in tensors} - filepath = os.path.join(save_directory, filename) - if safe_serialization: - # At some point we will need to deal better with save_function (used for TPU and other distributed - # joyfulness), but for now this enough. - safetensors.torch.save_file(shard, filepath, metadata={"format": "pt"}) - else: - torch.save(shard, filepath) + save_flashpack(model_to_save, save_directory, variant=variant) + else: + weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME + weights_name = _add_variant(weights_name, variant) + weights_name_pattern = weights_name.replace(".bin", "{suffix}.bin").replace( + ".safetensors", "{suffix}.safetensors" + ) - if state_dict_split.is_sharded: - index = { - "metadata": state_dict_split.metadata, - "weight_map": state_dict_split.tensor_to_filename, - } - save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME - save_index_file = os.path.join(save_directory, _add_variant(save_index_file, variant)) - # Save the index as well - with open(save_index_file, "w", encoding="utf-8") as f: - content = json.dumps(index, indent=2, sort_keys=True) + "\n" - f.write(content) - logger.info( - f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be " - f"split in {len(state_dict_split.filename_to_tensors)} checkpoint shards. You can find where each parameters has been saved in the " - f"index located at {save_index_file}." 
+ state_dict = model_to_save.state_dict() + state_dict_split = split_torch_state_dict_into_shards( + state_dict, max_shard_size=max_shard_size, filename_pattern=weights_name_pattern ) - else: - path_to_weights = os.path.join(save_directory, weights_name) - logger.info(f"Model weights saved in {path_to_weights}") + # Clean the folder from a previous save + if is_main_process: + for filename in os.listdir(save_directory): + if filename in state_dict_split.filename_to_tensors.keys(): + continue + full_filename = os.path.join(save_directory, filename) + if not os.path.isfile(full_filename): + continue + weights_without_ext = weights_name_pattern.replace(".bin", "").replace(".safetensors", "") + weights_without_ext = weights_without_ext.replace("{suffix}", "") + filename_without_ext = filename.replace(".bin", "").replace(".safetensors", "") + # make sure that file to be deleted matches format of sharded file, e.g. pytorch_model-00001-of-00005 + if ( + filename.startswith(weights_without_ext) + and _REGEX_SHARD.fullmatch(filename_without_ext) is not None + ): + os.remove(full_filename) + + # Save each shard + for filename, tensors in state_dict_split.filename_to_tensors.items(): + shard = {tensor: state_dict[tensor].contiguous() for tensor in tensors} + filepath = os.path.join(save_directory, filename) + if safe_serialization: + # At some point we will need to deal better with save_function (used for TPU and other distributed + # joyfulness), but for now this enough. 
+ safetensors.torch.save_file(shard, filepath, metadata={"format": "pt"}) + else: + torch.save(shard, filepath) + + # Save index file if sharded + if state_dict_split.is_sharded: + index = { + "metadata": state_dict_split.metadata, + "weight_map": state_dict_split.tensor_to_filename, + } + save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME + save_index_file = os.path.join(save_directory, _add_variant(save_index_file, variant)) + # Save the index as well + with open(save_index_file, "w", encoding="utf-8") as f: + content = json.dumps(index, indent=2, sort_keys=True) + "\n" + f.write(content) + logger.info( + f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be " + f"split in {len(state_dict_split.filename_to_tensors)} checkpoint shards. You can find where each parameters has been saved in the " + f"index located at {save_index_file}." + ) + else: + path_to_weights = os.path.join(save_directory, weights_name) + logger.info(f"Model weights saved in {path_to_weights}") + + # Push to hub if requested (common to both paths) if push_to_hub: # Create a new empty model card and eventually tag it model_card = load_or_create_model_card(repo_id, token=token) From ff26d9ffd5bf21688361d00296dff13a0e4734aa Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 22 Jan 2026 17:12:43 +0530 Subject: [PATCH 3/3] up --- src/diffusers/models/modeling_utils.py | 128 +++++++++++++------------ 1 file changed, 65 insertions(+), 63 deletions(-) diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index dc98af07b7e2..b29f16065732 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -1216,13 +1216,50 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P model = load_flax_checkpoint_in_pytorch_model(model, resolved_model_file) - else: - flashpack_file = None - if use_flashpack: + flashpack_file = None + 
if use_flashpack: + try: + flashpack_file = _get_model_file( + pretrained_model_name_or_path, + weights_name=_add_variant("model.flashpack", variant), + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + commit_hash=commit_hash, + dduf_entries=dduf_entries, + ) + except EnvironmentError: + flashpack_file = None + logger.warning( + "`use_flashpack` was specified to be True but no flashpack file was found. Resorting to non-flashpack alternatives." + ) + + if flashpack_file is None: + # in the case it is sharded, we have already the index + if is_sharded: + resolved_model_file, sharded_metadata = _get_checkpoint_shard_files( + pretrained_model_name_or_path, + index_file, + cache_dir=cache_dir, + proxies=proxies, + local_files_only=local_files_only, + token=token, + user_agent=user_agent, + revision=revision, + subfolder=subfolder or "", + dduf_entries=dduf_entries, + ) + elif use_safetensors: + logger.warning("Trying to load model weights with safetensors format.") try: - flashpack_file = _get_model_file( + resolved_model_file = _get_model_file( pretrained_model_name_or_path, - weights_name=_add_variant("model.flashpack", variant), + weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant), cache_dir=cache_dir, force_download=force_download, proxies=proxies, @@ -1234,68 +1271,33 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P commit_hash=commit_hash, dduf_entries=dduf_entries, ) - except EnvironmentError: - flashpack_file = None - if flashpack_file is None: - # in the case it is sharded, we have already the index - if is_sharded: - resolved_model_file, sharded_metadata = _get_checkpoint_shard_files( - pretrained_model_name_or_path, - index_file, - cache_dir=cache_dir, - proxies=proxies, - local_files_only=local_files_only, - token=token, - user_agent=user_agent, - revision=revision, - 
subfolder=subfolder or "", - dduf_entries=dduf_entries, + except IOError as e: + logger.error(f"An error occurred while trying to fetch {pretrained_model_name_or_path}: {e}") + if not allow_pickle: + raise + logger.warning( + "Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead." ) - elif use_safetensors: - logger.warning("Trying to load model weights with safetensors format.") - try: - resolved_model_file = _get_model_file( - pretrained_model_name_or_path, - weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant), - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - local_files_only=local_files_only, - token=token, - revision=revision, - subfolder=subfolder, - user_agent=user_agent, - commit_hash=commit_hash, - dduf_entries=dduf_entries, - ) - except IOError as e: - logger.error(f"An error occurred while trying to fetch {pretrained_model_name_or_path}: {e}") - if not allow_pickle: - raise - logger.warning( - "Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead." 
- ) - - if resolved_model_file is None and not is_sharded: - resolved_model_file = _get_model_file( - pretrained_model_name_or_path, - weights_name=_add_variant(WEIGHTS_NAME, variant), - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - local_files_only=local_files_only, - token=token, - revision=revision, - subfolder=subfolder, - user_agent=user_agent, - commit_hash=commit_hash, - dduf_entries=dduf_entries, - ) + if resolved_model_file is None and not is_sharded: + resolved_model_file = _get_model_file( + pretrained_model_name_or_path, + weights_name=_add_variant(WEIGHTS_NAME, variant), + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + commit_hash=commit_hash, + dduf_entries=dduf_entries, + ) - if not isinstance(resolved_model_file, list): - resolved_model_file = [resolved_model_file] + if not isinstance(resolved_model_file, list): + resolved_model_file = [resolved_model_file] # set dtype to instantiate the model under: # 1. If torch_dtype is not None, we use that dtype