diff --git a/tests/modular_pipelines/test_modular_pipelines_common.py b/tests/modular_pipelines/test_modular_pipelines_common.py index d897ed793376..8a65999b2006 100644 --- a/tests/modular_pipelines/test_modular_pipelines_common.py +++ b/tests/modular_pipelines/test_modular_pipelines_common.py @@ -5,6 +5,7 @@ import pytest import torch +from huggingface_hub import hf_hub_download import diffusers from diffusers import AutoModel, ComponentsManager, ModularPipeline, ModularPipelineBlocks @@ -32,6 +33,33 @@ ) +def _get_specified_components(path_or_repo_id, cache_dir=None): + if os.path.isdir(path_or_repo_id): + config_path = os.path.join(path_or_repo_id, "modular_model_index.json") + else: + try: + config_path = hf_hub_download( + repo_id=path_or_repo_id, + filename="modular_model_index.json", + local_dir=cache_dir, + ) + except Exception: + return None + + with open(config_path) as f: + config = json.load(f) + + components = set() + for k, v in config.items(): + if isinstance(v, (str, int, float, bool)): + continue + for entry in v: + if isinstance(entry, dict) and (entry.get("repo") or entry.get("pretrained_model_name_or_path")): + components.add(k) + break + return components + + class ModularPipelineTesterMixin: """ It provides a set of common tests for each modular pipeline, @@ -360,6 +388,39 @@ def test_save_from_pretrained(self, tmp_path): assert torch.abs(image_slices[0] - image_slices[1]).max() < 1e-3 + def test_load_expected_components_from_pretrained(self, tmp_path): + pipe = self.get_pipeline() + expected = _get_specified_components(self.pretrained_model_name_or_path, cache_dir=tmp_path) + if not expected: + pytest.skip("Skipping test as we couldn't fetch the expected components.") + + actual = { + name + for name in pipe.components + if getattr(pipe, name, None) is not None + and getattr(getattr(pipe, name), "_diffusers_load_id", None) not in (None, "null") + } + assert expected == actual, f"Component mismatch: missing={expected - actual}, unexpected={actual - expected}" + + def test_load_expected_components_from_save_pretrained(self, tmp_path): + pipe = self.get_pipeline() + save_dir = str(tmp_path / "saved-pipeline") + pipe.save_pretrained(save_dir) + + expected = _get_specified_components(save_dir) + loaded_pipe = ModularPipeline.from_pretrained(save_dir) + loaded_pipe.load_components(torch_dtype=torch.float32) + + actual = { + name + for name in loaded_pipe.components + if getattr(loaded_pipe, name, None) is not None + and getattr(getattr(loaded_pipe, name), "_diffusers_load_id", None) not in (None, "null") + } + assert expected == actual, ( + f"Component mismatch after save/load: missing={expected - actual}, unexpected={actual - expected}" + ) + def test_modular_index_consistency(self, tmp_path): pipe = self.get_pipeline() components_spec = pipe._component_specs