diff --git a/requirements.txt b/requirements.txt index ddc5da4..56ce879 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,3 @@ -lightning-sdk >=0.2.11 -lightning-utilities +litlogger >=2026.03.17 +lightning-sdk >=2026.03.31 +lightning-utilities<=0.15.3 diff --git a/setup.py b/setup.py index 053a597..dce9bc3 100755 --- a/setup.py +++ b/setup.py @@ -68,7 +68,7 @@ def _prepare_extras(requirements_dir: str = _PATH_REQUIRES, skip_files: tuple = include_package_data=True, zip_safe=False, keywords=["deep learning", "pytorch", "AI"], - python_requires=">=3.8", + python_requires=">=3.10", setup_requires=["wheel"], install_requires=_load_requirements(), extras_require=_prepare_extras(), diff --git a/src/litmodels/__about__.py b/src/litmodels/__about__.py index c71714e..e4bd307 100644 --- a/src/litmodels/__about__.py +++ b/src/litmodels/__about__.py @@ -1,4 +1,4 @@ -__version__ = "0.1.8" +__version__ = "0.2.0" __author__ = "Lightning-AI et al." __author_email__ = "community@lightning.ai" __license__ = "Apache-2.0" diff --git a/src/litmodels/__init__.py b/src/litmodels/__init__.py index 479a80b..e407229 100644 --- a/src/litmodels/__init__.py +++ b/src/litmodels/__init__.py @@ -2,11 +2,11 @@ import os +from litlogger.models import download_model, load_model, save_model, upload_model, upload_model_files + from litmodels.__about__ import * # noqa: F401, F403 _PACKAGE_ROOT = os.path.dirname(__file__) _PROJECT_ROOT = os.path.dirname(_PACKAGE_ROOT) -from litmodels.io import download_model, load_model, save_model, upload_model, upload_model_files # noqa: F401 - __all__ = ["download_model", "upload_model", "load_model", "save_model"] diff --git a/src/litmodels/io/__init__.py b/src/litmodels/io/__init__.py index 363f207..4b749c3 100644 --- a/src/litmodels/io/__init__.py +++ b/src/litmodels/io/__init__.py @@ -1,6 +1,12 @@ """Root package for Input/output.""" -from litmodels.io.cloud import download_model_files, upload_model_files # noqa: F401 -from litmodels.io.gateway import download_model, load_model, save_model, upload_model +from litlogger.models import ( + download_model, + download_model_files, + load_model, + save_model, + upload_model, + upload_model_files, +) -__all__ = ["download_model", "upload_model", "upload_model_files", "load_model", "save_model"] +__all__ = ["download_model", "download_model_files", "upload_model", "upload_model_files", "load_model", "save_model"] diff --git a/src/litmodels/io/cloud.py b/src/litmodels/io/cloud.py index 4061c5f..8c36ddb 100644 --- a/src/litmodels/io/cloud.py +++ b/src/litmodels/io/cloud.py @@ -1,136 +1,10 @@ -# Copyright The Lightning AI team. -# Licensed under the Apache License, Version 2.0 (the "License"); -# http://www.apache.org/licenses/LICENSE-2.0 -# -from pathlib import Path -from typing import TYPE_CHECKING, Optional, Union +"""Compatibility exports for the vendored litlogger model helpers.""" -from lightning_sdk.lightning_cloud.env import LIGHTNING_CLOUD_URL -from lightning_sdk.models import _extend_model_name_with_teamspace, _parse_org_teamspace_model_version -from lightning_sdk.models import delete_model as sdk_delete_model -from lightning_sdk.models import download_model as sdk_download_model -from lightning_sdk.models import upload_model as sdk_upload_model +from litlogger.models.cloud import ( + _list_available_teamspaces, + delete_model_version, + download_model_files, + upload_model_files, +) -import litmodels - -if TYPE_CHECKING: - from lightning_sdk.models import UploadedModelInfo - - -_SHOWED_MODEL_LINKS = [] - - -def _print_model_link(name: str, verbose: Union[bool, int]) -> None: - """Print a stable URL to the uploaded model. - - Args: - name: Model registry name. Teamspace defaults may be applied before URL construction. - verbose: Controls printing behavior: - - 0: do not print - - 1: print the link only once for a given model - - 2: always print the link - """ - name = _extend_model_name_with_teamspace(name) - org_name, teamspace_name, model_name, _ = _parse_org_teamspace_model_version(name) - - url = f"{LIGHTNING_CLOUD_URL}/{org_name}/{teamspace_name}/models/{model_name}" - msg = f"Model uploaded successfully. Link to the model: '{url}'" - if int(verbose) > 1: - print(msg) - elif url not in _SHOWED_MODEL_LINKS: - print(msg) - _SHOWED_MODEL_LINKS.append(url) - - -def upload_model_files( - name: str, - path: Union[str, Path, list[Union[str, Path]]], - progress_bar: bool = True, - cloud_account: Optional[str] = None, - verbose: Union[bool, int] = 1, - metadata: Optional[dict[str, str]] = None, -) -> "UploadedModelInfo": - """Upload local artifact(s) to Lightning Cloud using the SDK. - - Args: - name: Model registry name in the form 'organization/teamspace/modelname[:version]'. - path: File path, directory path, or list of paths to upload. - progress_bar: Whether to show a progress bar during upload. - cloud_account: Optional cloud account to store the model in, when it cannot be auto-resolved. - verbose: Verbosity for printing the model link (0 = no output, 1 = print once, 2 = print always). - metadata: Optional metadata to attach to the model/version. The package version is added automatically. - - Returns: - UploadedModelInfo describing the created or updated model version. - """ - if not metadata: - metadata = {} - metadata.update({"litModels": litmodels.__version__}) - info = sdk_upload_model( - name=name, - path=path, - progress_bar=progress_bar, - cloud_account=cloud_account, - metadata=metadata, - ) - if verbose: - _print_model_link(name, verbose) - return info - - -def download_model_files( - name: str, - download_dir: Union[str, Path] = ".", - progress_bar: bool = True, -) -> Union[str, list[str]]: - """Download artifact(s) for a model version using the SDK. - - Args: - name: Model registry name in the form 'organization/teamspace/modelname[:version]'. - download_dir: Directory where downloaded artifact(s) will be stored. Defaults to the current directory. - progress_bar: Whether to show a progress bar during download. - - Returns: - str | list[str]: Absolute path(s) to the downloaded artifact(s). - """ - return sdk_download_model( - name=name, - download_dir=download_dir, - progress_bar=progress_bar, - ) - - -def _list_available_teamspaces() -> dict[str, dict]: - """List teamspaces available to the authenticated user. - - Returns: - dict[str, dict]: Mapping of 'org/teamspace' to a metadata dictionary with details. - """ - from lightning_sdk.api import OrgApi, UserApi - from lightning_sdk.utils import resolve as sdk_resolvers - - org_api = OrgApi() - user = sdk_resolvers._get_authed_user() - teamspaces = {} - for ts in UserApi()._get_all_teamspace_memberships(""): - if ts.owner_type == "organization": - org = org_api._get_org_by_id(ts.owner_id) - teamspaces[f"{org.name}/{ts.name}"] = {"name": ts.name, "org": org.name} - elif ts.owner_type == "user": # todo: check also the name - teamspaces[f"{user.name}/{ts.name}"] = {"name": ts.name, "user": user} - else: - raise RuntimeError(f"Unknown organization type {ts.organization_type}") - return teamspaces - - -def delete_model_version( - name: str, - version: str, -) -> None: - """Delete a specific model version from the model store. - - Args: - name: Base model registry name in the form 'organization/teamspace/modelname'. - version: Identifier of the version to delete. This argument is required. - """ - sdk_delete_model(name=f"{name}:{version}") +__all__ = ["_list_available_teamspaces", "delete_model_version", "download_model_files", "upload_model_files"] diff --git a/src/litmodels/io/gateway.py b/src/litmodels/io/gateway.py index b0d7d2d..5b00b32 100644 --- a/src/litmodels/io/gateway.py +++ b/src/litmodels/io/gateway.py @@ -1,183 +1,5 @@ -import os -import tempfile -from pathlib import Path -from typing import TYPE_CHECKING, Any, Optional, Union +"""Compatibility exports for the vendored litlogger model helpers.""" -from litmodels.io.cloud import download_model_files, upload_model_files -from litmodels.io.utils import _KERAS_AVAILABLE, _PYTORCH_AVAILABLE, dump_pickle, load_pickle +from litlogger.models import download_model, load_model, save_model, upload_model -if _PYTORCH_AVAILABLE: - import torch - -if _KERAS_AVAILABLE: - from tensorflow import keras - -if TYPE_CHECKING: - from lightning_sdk.models import UploadedModelInfo - - -def upload_model( - name: str, - model: Union[str, Path], - progress_bar: bool = True, - cloud_account: Optional[str] = None, - verbose: Union[bool, int] = 1, - metadata: Optional[dict[str, str]] = None, -) -> "UploadedModelInfo": - """Upload a local artifact (file or directory) to Lightning Cloud Models. - - Args: - name: Model registry name in the form 'organization/teamspace/modelname[:version]'. - If the version is omitted, one may be assigned automatically by the service. - model: Path to a checkpoint file or a directory containing model artifacts. - progress_bar: Whether to show a progress bar during the upload. - cloud_account: Optional cloud account to store the model in, when it cannot be auto-resolved. - verbose: Verbosity of informational output (0 = silent, 1 = print link once, 2 = print link always). - metadata: Optional metadata key/value pairs to attach to the uploaded model/version. - - Returns: - UploadedModelInfo describing the created or updated model version. - - Raises: - ValueError: If `model` is not a filesystem path. For in-memory objects, use `save_model()` instead. - """ - if not isinstance(model, (str, Path)): - raise ValueError( - "The `model` argument should be a path to a file or folder, not an python object." - " For smooth integrations with PyTorch model, Lightning model and many more, use `save_model` instead." - ) - - return upload_model_files( - path=model, - name=name, - progress_bar=progress_bar, - cloud_account=cloud_account, - verbose=verbose, - metadata=metadata, - ) - - -def save_model( - name: str, - model: Union["torch.nn.Module", Any], - progress_bar: bool = True, - cloud_account: Optional[str] = None, - staging_dir: Optional[str] = None, - verbose: Union[bool, int] = 1, - metadata: Optional[dict[str, str]] = None, -) -> "UploadedModelInfo": - """Serialize an in-memory model and upload it to Lightning Cloud Models. - - Supported models: - - TorchScript (torch.jit.ScriptModule) → saved as .ts via model.save() - - PyTorch nn.Module → saved as .pth (state_dict via torch.save) - - Keras (tf.keras.Model) → saved as .keras via model.save() - - Any other Python object → saved as .pkl via pickle or joblib - - Args: - name: Model registry name in the form 'organization/teamspace/modelname[:version]'. - model: The in-memory model instance to serialize and upload. - progress_bar: Whether to show a progress bar during the upload. - cloud_account: Optional cloud account to store the model in, when it cannot be auto-resolved. - staging_dir: Optional temporary directory used for serialization. A new temp directory is created if omitted. - verbose: Verbosity of informational output (0 = silent, 1 = print link once, 2 = print link always). - metadata: Optional metadata key/value pairs to attach to the uploaded model/version. Integration markers are - added automatically. - - Returns: - UploadedModelInfo describing the created or updated model version. - - Raises: - ValueError: If `model` is a path. For file/folder uploads use `upload_model()` instead. - """ - if isinstance(model, (str, Path)): - raise ValueError( - "The `model` argument should be a PyTorch model or a Lightning model, not a path to a file." - " With file or folder path use `upload_model` instead." - ) - - if not staging_dir: - staging_dir = tempfile.mkdtemp() - # if LightningModule and isinstance(model, LightningModule): - # path = os.path.join(staging_dir, f"{model.__class__.__name__}.ckpt") - # model.save_checkpoint(path) - if _PYTORCH_AVAILABLE and isinstance(model, torch.jit.ScriptModule): - path = os.path.join(staging_dir, f"{model.__class__.__name__}.ts") - model.save(path) - elif _PYTORCH_AVAILABLE and isinstance(model, torch.nn.Module): - path = os.path.join(staging_dir, f"{model.__class__.__name__}.pth") - torch.save(model.state_dict(), path) - elif _KERAS_AVAILABLE and isinstance(model, keras.models.Model): - path = os.path.join(staging_dir, f"{model.__class__.__name__}.keras") - model.save(path) - else: - path = os.path.join(staging_dir, f"{model.__class__.__name__}.pkl") - dump_pickle(model=model, path=path) - - if not metadata: - metadata = {} - metadata.update({"litModels.integration": "save_model"}) - - return upload_model( - model=path, - name=name, - progress_bar=progress_bar, - cloud_account=cloud_account, - verbose=verbose, - metadata=metadata, - ) - - -def download_model( - name: str, - download_dir: Union[str, Path] = ".", - progress_bar: bool = True, -) -> Union[str, list[str]]: - """Download a model version from Lightning Cloud Models to a local directory. - - Args: - name: Model registry name in the form 'organization/teamspace/modelname[:version]'. - download_dir: Directory where the artifact(s) will be stored. Defaults to the current working directory. - progress_bar: Whether to show a progress bar during the download. - - Returns: - str | list[str]: Absolute path(s) to the downloaded file(s) or directory content. - """ - return download_model_files( - name=name, - download_dir=download_dir, - progress_bar=progress_bar, - ) - - -def load_model(name: str, download_dir: str = ".") -> Any: - """Download a model and load it into memory based on its file extension. - - Supported formats: - - .ts → torch.jit.load - - .keras → keras.models.load_model - - .pkl → pickle/joblib via load_pickle - - Args: - name: Model registry name in the form 'organization/teamspace/modelname[:version]'. - download_dir: Directory to store the downloaded artifact(s) before loading. Defaults to the current directory. - - Returns: - Any: The loaded model object. - - Raises: - NotImplementedError: If multiple files are downloaded or the file extension is not supported. - """ - download_paths = download_model(name=name, download_dir=download_dir) - # filter out all Markdown, TXT and RST files - download_paths = [p for p in download_paths if Path(p).suffix.lower() not in {".md", ".txt", ".rst"}] - if len(download_paths) > 1: - raise NotImplementedError("Downloaded model with multiple files is not supported yet.") - model_path = Path(download_dir) / download_paths[0] - if model_path.suffix.lower() == ".ts": - return torch.jit.load(model_path) - if model_path.suffix.lower() == ".keras": - return keras.models.load_model(model_path) - if model_path.suffix.lower() == ".pkl": - return load_pickle(path=model_path) - raise NotImplementedError(f"Loading model from {model_path.suffix} is not supported yet.") +__all__ = ["download_model", "load_model", "save_model", "upload_model"] diff --git a/src/litmodels/io/utils.py b/src/litmodels/io/utils.py index 32f490c..fb25ca3 100644 --- a/src/litmodels/io/utils.py +++ b/src/litmodels/io/utils.py @@ -1,68 +1,19 @@ -import os -import pickle -from collections.abc import Iterator -from contextlib import contextmanager -from pathlib import Path -from typing import Any, Union - -from lightning_utilities import module_available -from lightning_utilities.core.imports import RequirementCache - - -@contextmanager -def _suppress_os_stderr() -> Iterator[None]: - devnull_fd = os.open(os.devnull, os.O_WRONLY) - old_stderr_fd = os.dup(2) - os.dup2(devnull_fd, 2) # redirect stderr (fd 2) to /dev/null - os.close(devnull_fd) - try: - yield - finally: - os.dup2(old_stderr_fd, 2) # restore stderr - os.close(old_stderr_fd) - - -_JOBLIB_AVAILABLE = module_available("joblib") -_PYTORCH_AVAILABLE = module_available("torch") -with _suppress_os_stderr(): - _TENSORFLOW_AVAILABLE = module_available("tensorflow") - _KERAS_AVAILABLE = RequirementCache("tensorflow >=2.0.0") - -if _JOBLIB_AVAILABLE: - import joblib - - -def dump_pickle(model: Any, path: Union[str, Path]) -> None: - """Serialize a Python object to disk using joblib (if available) or pickle. - - Args: - model: The object to serialize. - path: Destination file path. - - Notes: - - Uses joblib with compression (level 7) when available for smaller artifacts. - - Falls back to pickle with the highest protocol otherwise. - """ - if _JOBLIB_AVAILABLE: - joblib.dump(model, filename=path, compress=7) - else: - with open(path, "wb") as fp: - pickle.dump(model, fp, protocol=pickle.HIGHEST_PROTOCOL) - - -def load_pickle(path: Union[str, Path]) -> Any: - """Load a Python object from a joblib/pickle file. - - Args: - path: Path to the serialized artifact. - - Returns: - Any: The deserialized object. - - Warning: - Loading pickle/joblib files can execute arbitrary code. Only open files from trusted sources. - """ - if _JOBLIB_AVAILABLE: - return joblib.load(path) - with open(path, "rb") as fp: - return pickle.load(fp) +"""Compatibility exports for the vendored litlogger model helpers.""" + +from litlogger.models.serialization import ( + _JOBLIB_AVAILABLE, + _KERAS_AVAILABLE, + _PYTORCH_AVAILABLE, + _TENSORFLOW_AVAILABLE, + dump_pickle, + load_pickle, +) + +__all__ = [ + "_JOBLIB_AVAILABLE", + "_KERAS_AVAILABLE", + "_PYTORCH_AVAILABLE", + "_TENSORFLOW_AVAILABLE", + "dump_pickle", + "load_pickle", +] diff --git a/tests/integrations/test_checkpoints.py b/tests/integrations/test_checkpoints.py index a3aec32..2efbf83 100644 --- a/tests/integrations/test_checkpoints.py +++ b/tests/integrations/test_checkpoints.py @@ -4,7 +4,6 @@ import pytest -import litmodels from tests.integrations import _SKIP_IF_LIGHTNING_MISSING, _SKIP_IF_PYTORCHLIGHTNING_MISSING @@ -20,8 +19,8 @@ ) @pytest.mark.parametrize("clear_all_local", [True, False]) @pytest.mark.parametrize("keep_all_uploaded", [True, False]) -@mock.patch("litmodels.io.cloud.sdk_delete_model") -@mock.patch("litmodels.io.cloud.sdk_upload_model") +@mock.patch("litmodels.integrations.checkpoints.delete_model_version") +@mock.patch("litmodels.integrations.checkpoints.upload_model") @mock.patch("litmodels.integrations.checkpoints.Auth") def test_lightning_checkpoint_callback( mock_auth, @@ -67,7 +66,6 @@ def test_lightning_checkpoint_callback( expected_org = expected_model_registry["org"] expected_teamspace = expected_model_registry["teamspace"] expected_model = expected_model_registry["model"] - mock_upload_model.return_value.name = f"{expected_org}/{expected_teamspace}/{expected_model}" monkeypatch.setattr( "litmodels.integrations.checkpoints.LitModelCheckpointMixin.default_model_name", mock.MagicMock(return_value=expected_boring_model), @@ -101,10 +99,8 @@ def test_lightning_checkpoint_callback( assert mock_upload_model.call_args_list == [ mock.call( name=f"{expected_org}/{expected_teamspace}/{expected_model}:{v}", - path=mock.ANY, - progress_bar=True, - cloud_account=None, - metadata={"litModels.integration": LitModelCheckpoint.__name__, "litModels": litmodels.__version__}, + model=mock.ANY, + metadata={"litModels.integration": LitModelCheckpoint.__name__}, ) for v in ("epoch=0-step=64", "epoch=1-step=128") ] @@ -115,12 +111,13 @@ def test_lightning_checkpoint_callback( assert mock_delete_model.call_count == expected_cloud_removals if expected_cloud_removals: mock_delete_model.assert_called_once_with( - name=f"{expected_org}/{expected_teamspace}/{expected_model}:epoch=0-step=64" + name=f"{expected_org}/{expected_teamspace}/{expected_model}", + version="epoch=0-step=64", ) # Verify paths match the expected pattern for call_args in mock_upload_model.call_args_list: - path = call_args[1]["path"] + path = call_args[1]["model"] assert re.match(r".*[/\\]lightning_logs[/\\]version_\d+[/\\]checkpoints[/\\]epoch=\d+-step=\d+\.ckpt$", path) diff --git a/tests/test_io_cloud.py b/tests/test_io_cloud.py index 09ed4b0..0a859de 100644 --- a/tests/test_io_cloud.py +++ b/tests/test_io_cloud.py @@ -1,167 +1,37 @@ -import os -from contextlib import nullcontext -from unittest import mock - -import joblib -import pytest -import torch -import torch.jit as torch_jit -from sklearn import svm -from torch.nn import Module +from litlogger import models as litlogger_models import litmodels -from litmodels import download_model, load_model, save_model -from litmodels.io import upload_model_files -from litmodels.io.utils import _KERAS_AVAILABLE -from tests.integrations import LIT_ORG, LIT_TEAMSPACE - - -@pytest.mark.parametrize("name", ["/too/many/slashes", "org/model", "model-name"]) -@pytest.mark.parametrize("in_studio", [True, False]) -@mock.patch("litmodels.io.cloud.sdk_upload_model") -def test_upload_wrong_model_name(mock_sdk_upload, name, in_studio, monkeypatch): - if in_studio: - # mock env variables as it would run in studio - monkeypatch.setenv("LIGHTNING_ORG", LIT_ORG) - monkeypatch.setenv("LIGHTNING_TEAMSPACE", LIT_TEAMSPACE) - monkeypatch.setattr("lightning_sdk.organization.Organization", mock.MagicMock) - monkeypatch.setattr("lightning_sdk.teamspace.Teamspace", mock.MagicMock) - monkeypatch.setattr("lightning_sdk.teamspace.TeamspaceApi", mock.MagicMock) - monkeypatch.setattr("lightning_sdk.models._get_teamspace", mock.MagicMock) - - in_studio_only_name = in_studio and name == "model-name" - with ( - pytest.raises(ValueError, match=r".*organization/teamspace/model.*") - if not in_studio_only_name - else nullcontext() - ): - upload_model_files(path="path/to/checkpoint", name=name) - - -@pytest.mark.parametrize("name", ["/too/many/slashes", "org/model", "model-name"]) -@pytest.mark.parametrize("in_studio", [True, False]) -def test_download_wrong_model_name(name, in_studio, monkeypatch): - if in_studio: - # mock env variables as it would run in studio - monkeypatch.setenv("LIGHTNING_ORG", LIT_ORG) - monkeypatch.setenv("LIGHTNING_TEAMSPACE", LIT_TEAMSPACE) - monkeypatch.setattr("lightning_sdk.organization.Organization", mock.MagicMock) - monkeypatch.setattr("lightning_sdk.teamspace.Teamspace", mock.MagicMock) - monkeypatch.setattr("lightning_sdk.models.TeamspaceApi", mock.MagicMock) - in_studio_only_name = in_studio and name == "model-name" - with ( - pytest.raises(ValueError, match=r".*organization/teamspace/model.*") - if not in_studio_only_name - else nullcontext() - ): - download_model(name=name) - - -@pytest.mark.parametrize( - ("model", "model_path", "verbose"), - [ - # ("path/to/checkpoint", "path/to/checkpoint", False), - # (BoringModel(), "%s/BoringModel.ckpt"), - (torch_jit.script(Module()), f"%s{os.path.sep}RecursiveScriptModule.ts", True), - (Module(), f"%s{os.path.sep}Module.pth", True), - (svm.SVC(), f"%s{os.path.sep}SVC.pkl", 1), - ], -) -@mock.patch("litmodels.io.cloud.sdk_upload_model") -def test_upload_model(mock_upload_model, tmp_path, model, model_path, verbose): - mock_upload_model.return_value.name = "org-name/teamspace/model-name" - - # The lit-logger function is just a wrapper around the SDK function - save_model( - model=model, - name="org-name/teamspace/model-name", - cloud_account="cluster_id", - staging_dir=str(tmp_path), - verbose=verbose, - ) - expected_path = model_path % str(tmp_path) if "%" in model_path else model_path - mock_upload_model.assert_called_once_with( - path=expected_path, - name="org-name/teamspace/model-name", - cloud_account="cluster_id", - progress_bar=True, - metadata={"litModels": litmodels.__version__, "litModels.integration": "save_model"}, - ) - - -@mock.patch("litmodels.io.cloud.sdk_download_model") -def test_download_model(mock_download_model): - # The lit-logger function is just a wrapper around the SDK function - download_model( - name="org-name/teamspace/model-name", - download_dir="where/to/download", - ) - mock_download_model.assert_called_once_with( - name="org-name/teamspace/model-name", download_dir="where/to/download", progress_bar=True - ) - - -@mock.patch("litmodels.io.cloud.sdk_download_model") -def test_load_model_pickle(mock_download_model, tmp_path): - # create a dummy model file - model_file = tmp_path / "dummy_model.pkl" - test_data = svm.SVC() - joblib.dump(test_data, model_file) - mock_download_model.return_value = [str(model_file.name)] - - # The lit-logger function is just a wrapper around the SDK function - model = load_model( - name="org-name/teamspace/model-name", - download_dir=str(tmp_path), - ) - mock_download_model.assert_called_once_with( - name="org-name/teamspace/model-name", download_dir=str(tmp_path), progress_bar=True - ) - assert isinstance(model, svm.SVC) - - -@mock.patch("litmodels.io.cloud.sdk_download_model") -def test_load_model_torch_jit(mock_download_model, tmp_path): - # create a dummy model file - model_file = tmp_path / "dummy_model.ts" - test_data = torch_jit.script(Module()) - test_data.save(model_file) - mock_download_model.return_value = [str(model_file.name)] - - # The lit-logger function is just a wrapper around the SDK function - model = load_model( - name="org-name/teamspace/model-name", - download_dir=str(tmp_path), - ) - mock_download_model.assert_called_once_with( - name="org-name/teamspace/model-name", download_dir=str(tmp_path), progress_bar=True - ) - assert isinstance(model, torch.jit.ScriptModule) - - -@pytest.mark.skipif(not _KERAS_AVAILABLE, reason="TensorFlow/Keras is not available") -@mock.patch("litmodels.io.cloud.sdk_download_model") -def test_load_model_tf_keras(mock_download_model, tmp_path): - from tensorflow import keras - - # create a dummy model file - model_file = tmp_path / "dummy_model.keras" - # Define the model - model = keras.Sequential([ - keras.layers.Dense(10, input_shape=(784,), name="dense_1"), - keras.layers.Dense(10, name="dense_2"), - ]) - model.compile(optimizer="adam", loss="categorical_crossentropy") - model.save(model_file) - # prepare mocked SDK download function - mock_download_model.return_value = [str(model_file.name)] - - # The lit-logger function is just a wrapper around the SDK function - model = load_model( - name="org-name/teamspace/model-name", - download_dir=str(tmp_path), - ) - mock_download_model.assert_called_once_with( - name="org-name/teamspace/model-name", download_dir=str(tmp_path), progress_bar=True - ) - assert isinstance(model, keras.models.Model) +from litmodels import io as litmodels_io +from litmodels.io import cloud as litmodels_cloud +from litmodels.io import gateway as litmodels_gateway +from litmodels.io import utils as litmodels_utils + + +def test_top_level_exports_are_reexported_from_litlogger(): + assert litmodels.download_model is litlogger_models.download_model + assert litmodels.load_model is litlogger_models.load_model + assert litmodels.save_model is litlogger_models.save_model + assert litmodels.upload_model is litlogger_models.upload_model + assert litmodels.upload_model_files is litlogger_models.upload_model_files + + +def test_io_exports_are_reexported_from_litlogger(): + assert litmodels_io.download_model is litlogger_models.download_model + assert litmodels_io.download_model_files is litlogger_models.download_model_files + assert litmodels_io.load_model is litlogger_models.load_model + assert litmodels_io.save_model is litlogger_models.save_model + assert litmodels_io.upload_model is litlogger_models.upload_model + assert litmodels_io.upload_model_files is litlogger_models.upload_model_files + assert litmodels_gateway.download_model is litlogger_models.download_model + assert litmodels_gateway.load_model is litlogger_models.load_model + assert litmodels_gateway.save_model is litlogger_models.save_model + assert litmodels_gateway.upload_model is litlogger_models.upload_model + + +def test_cloud_and_utils_dependencies_are_reexported(): + assert litmodels_cloud.download_model_files is litlogger_models.download_model_files + assert litmodels_cloud.upload_model_files is litlogger_models.upload_model_files + assert callable(litmodels_cloud.delete_model_version) + assert callable(litmodels_cloud._list_available_teamspaces) + assert callable(litmodels_utils.dump_pickle) + assert callable(litmodels_utils.load_pickle)