Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
lightning-sdk >=0.2.11
lightning-utilities
litlogger >=2026.03.17
lightning-sdk >=2026.03.31
lightning-utilities<=0.15.3
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def _prepare_extras(requirements_dir: str = _PATH_REQUIRES, skip_files: tuple =
include_package_data=True,
zip_safe=False,
keywords=["deep learning", "pytorch", "AI"],
python_requires=">=3.8",
python_requires=">=3.10",
setup_requires=["wheel"],
install_requires=_load_requirements(),
extras_require=_prepare_extras(),
Expand Down
2 changes: 1 addition & 1 deletion src/litmodels/__about__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "0.1.8"
__version__ = "0.2.0"
__author__ = "Lightning-AI et al."
__author_email__ = "community@lightning.ai"
__license__ = "Apache-2.0"
Expand Down
4 changes: 2 additions & 2 deletions src/litmodels/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@

import os

from litlogger.models import download_model, load_model, save_model, upload_model, upload_model_files

from litmodels.__about__ import * # noqa: F401, F403

_PACKAGE_ROOT = os.path.dirname(__file__)
_PROJECT_ROOT = os.path.dirname(_PACKAGE_ROOT)

from litmodels.io import download_model, load_model, save_model, upload_model, upload_model_files # noqa: F401

__all__ = ["download_model", "upload_model", "load_model", "save_model"]
12 changes: 9 additions & 3 deletions src/litmodels/io/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
"""Root package for Input/output."""

from litmodels.io.cloud import download_model_files, upload_model_files # noqa: F401
from litmodels.io.gateway import download_model, load_model, save_model, upload_model
from litlogger.models import (
download_model,
download_model_files,
load_model,
save_model,
upload_model,
upload_model_files,
)

__all__ = ["download_model", "upload_model", "upload_model_files", "load_model", "save_model"]
__all__ = ["download_model", "download_model_files", "upload_model", "upload_model_files", "load_model", "save_model"]
142 changes: 8 additions & 134 deletions src/litmodels/io/cloud.py
Original file line number Diff line number Diff line change
@@ -1,136 +1,10 @@
# Copyright The Lightning AI team.
# Licensed under the Apache License, Version 2.0 (the "License");
# http://www.apache.org/licenses/LICENSE-2.0
#
from pathlib import Path
from typing import TYPE_CHECKING, Optional, Union
"""Compatibility exports for the vendored litlogger model helpers."""

from lightning_sdk.lightning_cloud.env import LIGHTNING_CLOUD_URL
from lightning_sdk.models import _extend_model_name_with_teamspace, _parse_org_teamspace_model_version
from lightning_sdk.models import delete_model as sdk_delete_model
from lightning_sdk.models import download_model as sdk_download_model
from lightning_sdk.models import upload_model as sdk_upload_model
from litlogger.models.cloud import (
_list_available_teamspaces,
delete_model_version,
download_model_files,
upload_model_files,
)

import litmodels

if TYPE_CHECKING:
from lightning_sdk.models import UploadedModelInfo


_SHOWED_MODEL_LINKS = []


def _print_model_link(name: str, verbose: Union[bool, int]) -> None:
"""Print a stable URL to the uploaded model.

Args:
name: Model registry name. Teamspace defaults may be applied before URL construction.
verbose: Controls printing behavior:
- 0: do not print
- 1: print the link only once for a given model
- 2: always print the link
"""
name = _extend_model_name_with_teamspace(name)
org_name, teamspace_name, model_name, _ = _parse_org_teamspace_model_version(name)

url = f"{LIGHTNING_CLOUD_URL}/{org_name}/{teamspace_name}/models/{model_name}"
msg = f"Model uploaded successfully. Link to the model: '{url}'"
if int(verbose) > 1:
print(msg)
elif url not in _SHOWED_MODEL_LINKS:
print(msg)
_SHOWED_MODEL_LINKS.append(url)


def upload_model_files(
name: str,
path: Union[str, Path, list[Union[str, Path]]],
progress_bar: bool = True,
cloud_account: Optional[str] = None,
verbose: Union[bool, int] = 1,
metadata: Optional[dict[str, str]] = None,
) -> "UploadedModelInfo":
"""Upload local artifact(s) to Lightning Cloud using the SDK.

Args:
name: Model registry name in the form 'organization/teamspace/modelname[:version]'.
path: File path, directory path, or list of paths to upload.
progress_bar: Whether to show a progress bar during upload.
cloud_account: Optional cloud account to store the model in, when it cannot be auto-resolved.
verbose: Verbosity for printing the model link (0 = no output, 1 = print once, 2 = print always).
metadata: Optional metadata to attach to the model/version. The package version is added automatically.

Returns:
UploadedModelInfo describing the created or updated model version.
"""
if not metadata:
metadata = {}
metadata.update({"litModels": litmodels.__version__})
info = sdk_upload_model(
name=name,
path=path,
progress_bar=progress_bar,
cloud_account=cloud_account,
metadata=metadata,
)
if verbose:
_print_model_link(name, verbose)
return info


def download_model_files(
name: str,
download_dir: Union[str, Path] = ".",
progress_bar: bool = True,
) -> Union[str, list[str]]:
"""Download artifact(s) for a model version using the SDK.

Args:
name: Model registry name in the form 'organization/teamspace/modelname[:version]'.
download_dir: Directory where downloaded artifact(s) will be stored. Defaults to the current directory.
progress_bar: Whether to show a progress bar during download.

Returns:
str | list[str]: Absolute path(s) to the downloaded artifact(s).
"""
return sdk_download_model(
name=name,
download_dir=download_dir,
progress_bar=progress_bar,
)


def _list_available_teamspaces() -> dict[str, dict]:
"""List teamspaces available to the authenticated user.

Returns:
dict[str, dict]: Mapping of 'org/teamspace' to a metadata dictionary with details.
"""
from lightning_sdk.api import OrgApi, UserApi
from lightning_sdk.utils import resolve as sdk_resolvers

org_api = OrgApi()
user = sdk_resolvers._get_authed_user()
teamspaces = {}
for ts in UserApi()._get_all_teamspace_memberships(""):
if ts.owner_type == "organization":
org = org_api._get_org_by_id(ts.owner_id)
teamspaces[f"{org.name}/{ts.name}"] = {"name": ts.name, "org": org.name}
elif ts.owner_type == "user": # todo: check also the name
teamspaces[f"{user.name}/{ts.name}"] = {"name": ts.name, "user": user}
else:
raise RuntimeError(f"Unknown organization type {ts.organization_type}")
return teamspaces


def delete_model_version(
name: str,
version: str,
) -> None:
"""Delete a specific model version from the model store.

Args:
name: Base model registry name in the form 'organization/teamspace/modelname'.
version: Identifier of the version to delete. This argument is required.
"""
sdk_delete_model(name=f"{name}:{version}")
__all__ = ["_list_available_teamspaces", "delete_model_version", "download_model_files", "upload_model_files"]
184 changes: 3 additions & 181 deletions src/litmodels/io/gateway.py
Original file line number Diff line number Diff line change
@@ -1,183 +1,5 @@
import os
import tempfile
from pathlib import Path
from typing import TYPE_CHECKING, Any, Optional, Union
"""Compatibility exports for the vendored litlogger model helpers."""

from litmodels.io.cloud import download_model_files, upload_model_files
from litmodels.io.utils import _KERAS_AVAILABLE, _PYTORCH_AVAILABLE, dump_pickle, load_pickle
from litlogger.models import download_model, load_model, save_model, upload_model

if _PYTORCH_AVAILABLE:
import torch

if _KERAS_AVAILABLE:
from tensorflow import keras

if TYPE_CHECKING:
from lightning_sdk.models import UploadedModelInfo


def upload_model(
name: str,
model: Union[str, Path],
progress_bar: bool = True,
cloud_account: Optional[str] = None,
verbose: Union[bool, int] = 1,
metadata: Optional[dict[str, str]] = None,
) -> "UploadedModelInfo":
"""Upload a local artifact (file or directory) to Lightning Cloud Models.

Args:
name: Model registry name in the form 'organization/teamspace/modelname[:version]'.
If the version is omitted, one may be assigned automatically by the service.
model: Path to a checkpoint file or a directory containing model artifacts.
progress_bar: Whether to show a progress bar during the upload.
cloud_account: Optional cloud account to store the model in, when it cannot be auto-resolved.
verbose: Verbosity of informational output (0 = silent, 1 = print link once, 2 = print link always).
metadata: Optional metadata key/value pairs to attach to the uploaded model/version.

Returns:
UploadedModelInfo describing the created or updated model version.

Raises:
ValueError: If `model` is not a filesystem path. For in-memory objects, use `save_model()` instead.
"""
if not isinstance(model, (str, Path)):
raise ValueError(
"The `model` argument should be a path to a file or folder, not an python object."
" For smooth integrations with PyTorch model, Lightning model and many more, use `save_model` instead."
)

return upload_model_files(
path=model,
name=name,
progress_bar=progress_bar,
cloud_account=cloud_account,
verbose=verbose,
metadata=metadata,
)


def save_model(
name: str,
model: Union["torch.nn.Module", Any],
progress_bar: bool = True,
cloud_account: Optional[str] = None,
staging_dir: Optional[str] = None,
verbose: Union[bool, int] = 1,
metadata: Optional[dict[str, str]] = None,
) -> "UploadedModelInfo":
"""Serialize an in-memory model and upload it to Lightning Cloud Models.

Supported models:
- TorchScript (torch.jit.ScriptModule) → saved as .ts via model.save()
- PyTorch nn.Module → saved as .pth (state_dict via torch.save)
- Keras (tf.keras.Model) → saved as .keras via model.save()
- Any other Python object → saved as .pkl via pickle or joblib

Args:
name: Model registry name in the form 'organization/teamspace/modelname[:version]'.
model: The in-memory model instance to serialize and upload.
progress_bar: Whether to show a progress bar during the upload.
cloud_account: Optional cloud account to store the model in, when it cannot be auto-resolved.
staging_dir: Optional temporary directory used for serialization. A new temp directory is created if omitted.
verbose: Verbosity of informational output (0 = silent, 1 = print link once, 2 = print link always).
metadata: Optional metadata key/value pairs to attach to the uploaded model/version. Integration markers are
added automatically.

Returns:
UploadedModelInfo describing the created or updated model version.

Raises:
ValueError: If `model` is a path. For file/folder uploads use `upload_model()` instead.
"""
if isinstance(model, (str, Path)):
raise ValueError(
"The `model` argument should be a PyTorch model or a Lightning model, not a path to a file."
" With file or folder path use `upload_model` instead."
)

if not staging_dir:
staging_dir = tempfile.mkdtemp()
# if LightningModule and isinstance(model, LightningModule):
# path = os.path.join(staging_dir, f"{model.__class__.__name__}.ckpt")
# model.save_checkpoint(path)
if _PYTORCH_AVAILABLE and isinstance(model, torch.jit.ScriptModule):
path = os.path.join(staging_dir, f"{model.__class__.__name__}.ts")
model.save(path)
elif _PYTORCH_AVAILABLE and isinstance(model, torch.nn.Module):
path = os.path.join(staging_dir, f"{model.__class__.__name__}.pth")
torch.save(model.state_dict(), path)
elif _KERAS_AVAILABLE and isinstance(model, keras.models.Model):
path = os.path.join(staging_dir, f"{model.__class__.__name__}.keras")
model.save(path)
else:
path = os.path.join(staging_dir, f"{model.__class__.__name__}.pkl")
dump_pickle(model=model, path=path)

if not metadata:
metadata = {}
metadata.update({"litModels.integration": "save_model"})

return upload_model(
model=path,
name=name,
progress_bar=progress_bar,
cloud_account=cloud_account,
verbose=verbose,
metadata=metadata,
)


def download_model(
name: str,
download_dir: Union[str, Path] = ".",
progress_bar: bool = True,
) -> Union[str, list[str]]:
"""Download a model version from Lightning Cloud Models to a local directory.

Args:
name: Model registry name in the form 'organization/teamspace/modelname[:version]'.
download_dir: Directory where the artifact(s) will be stored. Defaults to the current working directory.
progress_bar: Whether to show a progress bar during the download.

Returns:
str | list[str]: Absolute path(s) to the downloaded file(s) or directory content.
"""
return download_model_files(
name=name,
download_dir=download_dir,
progress_bar=progress_bar,
)


def load_model(name: str, download_dir: str = ".") -> Any:
"""Download a model and load it into memory based on its file extension.

Supported formats:
- .ts → torch.jit.load
- .keras → keras.models.load_model
- .pkl → pickle/joblib via load_pickle

Args:
name: Model registry name in the form 'organization/teamspace/modelname[:version]'.
download_dir: Directory to store the downloaded artifact(s) before loading. Defaults to the current directory.

Returns:
Any: The loaded model object.

Raises:
NotImplementedError: If multiple files are downloaded or the file extension is not supported.
"""
download_paths = download_model(name=name, download_dir=download_dir)
# filter out all Markdown, TXT and RST files
download_paths = [p for p in download_paths if Path(p).suffix.lower() not in {".md", ".txt", ".rst"}]
if len(download_paths) > 1:
raise NotImplementedError("Downloaded model with multiple files is not supported yet.")
model_path = Path(download_dir) / download_paths[0]
if model_path.suffix.lower() == ".ts":
return torch.jit.load(model_path)
if model_path.suffix.lower() == ".keras":
return keras.models.load_model(model_path)
if model_path.suffix.lower() == ".pkl":
return load_pickle(path=model_path)
raise NotImplementedError(f"Loading model from {model_path.suffix} is not supported yet.")
__all__ = ["download_model", "load_model", "save_model", "upload_model"]
Loading
Loading