diff --git a/MANIFEST.in b/MANIFEST.in index 307294a..6e2c284 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1 +1,3 @@ -include dlclive/check_install/* \ No newline at end of file +include dlclive/check_install/* +include dlclive/modelzoo/model_configs/*.yaml +include dlclive/modelzoo/project_configs/*.yaml \ No newline at end of file diff --git a/dlclive/modelzoo/pytorch_model_zoo_export.py b/dlclive/modelzoo/pytorch_model_zoo_export.py index 616857d..f5cc39a 100644 --- a/dlclive/modelzoo/pytorch_model_zoo_export.py +++ b/dlclive/modelzoo/pytorch_model_zoo_export.py @@ -14,7 +14,18 @@ def export_modelzoo_model( detector_name: str | None = None, ) -> None: """ + Export a DeepLabCut Model Zoo model to a single .pt file. + Downloads the model configuration and weights from HuggingFace, bundles them + together (optionally with a detector), and saves as a single torch archive. + Skips export if the output file already exists. + + Args: + export_path: Arbitrary destination path for the exported .pt file. + super_animal: Super animal dataset name (e.g. "superanimal_quadruped"). + model_name: Pose model architecture name (e.g. "resnet_50"). + detector_name: Optional detector model name. If provided, detector + weights are included in the export. """ Path(export_path).parent.mkdir(parents=True, exist_ok=True) if Path(export_path).exists(): diff --git a/dlclive/modelzoo/utils.py b/dlclive/modelzoo/utils.py index 3857d14..f9bf2f7 100644 --- a/dlclive/modelzoo/utils.py +++ b/dlclive/modelzoo/utils.py @@ -9,6 +9,7 @@ from pathlib import Path from dlclibrary.dlcmodelzoo.modelzoo_download import download_huggingface_model +from dlclibrary.dlcmodelzoo.modelzoo_download import _load_model_names as huggingface_model_paths from ruamel.yaml import YAML from dlclive.modelzoo.resolve_config import update_config @@ -49,10 +50,7 @@ def list_available_projects() -> list[str]: def list_available_combinations() -> list[str]: - models = list_available_models() - projects = list_available_projects() - combinations = ["_".join([p, m]) for p in projects for m in models] - return combinations + return list(huggingface_model_paths.keys()) def read_config_as_dict(config_path: str | Path) -> dict: diff --git a/pyproject.toml b/pyproject.toml index 8ce8c13..486e124 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -89,7 +89,11 @@ include-package-data = true include = ["dlclive*"] [tool.setuptools.package-data] -dlclive = ["check_install/*"] +dlclive = [ + "check_install/*", + "modelzoo/model_configs/*.yaml", + "modelzoo/project_configs/*.yaml", +] # [tool.ruff] # lint.select = ["E", "F", "B", "I", "UP"]