97 changes: 90 additions & 7 deletions zetta_utils/convnet/utils.py
@@ -1,35 +1,58 @@
from __future__ import annotations

import io
import os
from typing import Literal, Optional, Sequence, Union, overload

import cachetools
import fsspec
import onnx
import onnx2torch
import torch
import xxhash
from numpy import typing as npt
from typeguard import typechecked

from zetta_utils import builder, log, tensor_ops
from zetta_utils.mazepa import semaphore

logger = log.get_logger("zetta_utils")


TENSORRT_AVAILABLE = False
try:
import torch_tensorrt

TENSORRT_AVAILABLE = True
except ImportError as e:
print(f"torch_tensorrt is not available: {e}")


@builder.register("load_model")
@typechecked
# @typechecked
def load_model(
path: str, device: Union[str, torch.device] = "cpu", use_cache: bool = False
path: str,
device: Union[str, torch.device] = "cpu",
use_cache: bool = False,
input_shape: Sequence[int] | None = None,
tensorrt_enabled: bool = False,
tensorrt_cache_dir: str = ".", # defaults to the current working directory
) -> torch.nn.Module: # pragma: no cover
if use_cache:
result = _load_model_cached(path, device)
result = _load_model_cached(
path, device, input_shape, tensorrt_enabled, tensorrt_cache_dir
)
else:
result = _load_model(path, device)
result = _load_model(path, device, input_shape, tensorrt_enabled, tensorrt_cache_dir)
return result


def _load_model(
path: str, device: Union[str, torch.device] = "cpu"
path: str,
device: Union[str, torch.device] = "cpu",
input_shape: Sequence[int] | None = None,
tensorrt_enabled: bool = False,
tensorrt_cache_dir: str = ".",
) -> torch.nn.Module: # pragma: no cover
logger.debug(f"Loading model from '{path}'")
if path.endswith(".json"):
@@ -40,8 +63,53 @@ def _load_model(
elif path.endswith(".onnx"):
with fsspec.open(path, "rb") as f:
result = onnx2torch.convert(onnx.load(f)).to(device)
elif path.endswith(".ts"):
# load a cached TensorRT model
result = torch.export.load(path)
else:
raise ValueError(f"Unsupported file format: {path}")

if tensorrt_enabled:
if not TENSORRT_AVAILABLE:
raise RuntimeError("torch_tensorrt is not installed!")

with semaphore("trt_compilation"):
# TensorRT compilation should not run in many threads concurrently, or the GPU will run out of memory.
# Ideally, only one thread compiles and the others then load the cached model.

trt_fname = (
str(xxhash.xxh128(str((path, tuple(input_shape))).encode("utf-8")).hexdigest())
+ ".trt.ts"
)
cache_path = os.path.join(tensorrt_cache_dir, trt_fname)

# Try to load the optimized model from cache
try:
with fsspec.open(cache_path, "rb") as f:
return torch_tensorrt.load(f)
except FileNotFoundError:
print(f"Cache not found. Compiling TensorRT model: {cache_path}")
except Exception as e:
print(f"Error loading TensorRT model from cache: {e}")

example_in = torch.rand(input_shape).to(device=device)

with torch.inference_mode():
trace = torch.jit.trace(result, example_in)

result = torch_tensorrt.ts.compile(
trace,
inputs=[example_in],
truncate_long_and_double=True,
enabled_precisions={torch.float, torch.half},
debug=False,
)

# save optimized model
with fsspec.open(cache_path, "wb") as f:
torch_tensorrt.save(result, f, output_format="torchscript", inputs=[example_in])
print(f"Compiled TensorRT model saved to cache: {cache_path}")

return result
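
For reference, a minimal sketch of the cache-key derivation used above; the model path and input shape are hypothetical, and tensorrt_cache_dir is left at its default:

import os
import xxhash

# Hypothetical model path and input shape; the cache file name is a hash of both,
# so the same model compiled for a different shape gets a separate .trt.ts file.
path = "gs://bucket/model.onnx"
input_shape = (1, 1, 1024, 1024)
trt_fname = xxhash.xxh128(str((path, tuple(input_shape))).encode("utf-8")).hexdigest() + ".trt.ts"
cache_path = os.path.join(".", trt_fname)  # "." is the default tensorrt_cache_dir
print(cache_path)  # e.g. ./<32-hex-digit hash>.trt.ts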


@@ -112,16 +180,31 @@ def load_and_run_model(


@typechecked
def load_and_run_model(path, data_in, device=None, use_cache=True): # pragma: no cover
def load_and_run_model(
path,
data_in,
device=None,
use_cache=True,
tensorrt_enabled: bool = False,
tensorrt_cache_dir: str = ".",
): # pragma: no cover

if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"

model = load_model(path=path, device=device, use_cache=use_cache)
model = load_model(
path=path,
device=device,
use_cache=use_cache,
input_shape=data_in.shape,
tensorrt_enabled=tensorrt_enabled,
tensorrt_cache_dir=tensorrt_cache_dir,
)

autocast_device = device.type if isinstance(device, torch.device) else str(device)
with torch.inference_mode(): # uses less memory when used with JITs
with torch.autocast(device_type=autocast_device):
output = model(tensor_ops.convert.to_torch(data_in, device=device))
output = tensor_ops.convert.astype(output, reference=data_in, cast=True)

return output
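
A minimal usage sketch of the new TensorRT options; the bucket, model path, and input shape are hypothetical, and it assumes a CUDA-capable device and numpy input:

import numpy as np
from zetta_utils.convnet.utils import load_and_run_model

# Hypothetical ONNX model and input; the shape must match what the model expects.
data_in = np.random.rand(1, 1, 256, 256).astype(np.float32)
output = load_and_run_model(
    path="gs://my-bucket/model.onnx",
    data_in=data_in,
    tensorrt_enabled=True,                # compile (or load a cached) TensorRT module
    tensorrt_cache_dir="/tmp/trt_cache",  # compiled .trt.ts files are written here
)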
9 changes: 5 additions & 4 deletions zetta_utils/mazepa/semaphores.py
@@ -18,7 +18,7 @@
from zetta_utils.common.pprint import lrpad

logger = log.get_logger("mazepa")
SemaphoreType = Literal["read", "write", "cuda", "cpu"]
SemaphoreType = Literal["read", "write", "cuda", "cpu", "trt_compilation"]

DEFAULT_SEMA_COUNT = 1
TIMING_FORMAT = "dddd" # wait_time, lease_time, lease_count, start_time
@@ -152,14 +152,15 @@ def configure_semaphores(
Context manager for creating and destroying semaphores.
"""

sema_types_to_check: List[SemaphoreType] = ["read", "write", "cuda", "cpu"]
sema_types_to_check: List[SemaphoreType] = ["read", "write", "cuda", "cpu", "trt_compilation"]
if semaphores_spec is not None:
for name in semaphores_spec:
if name not in get_args(SemaphoreType):
raise ValueError(f"`{name}` is not a valid semaphore type.")
try:
for sema_type in sema_types_to_check:
assert semaphores_spec[sema_type] >= 0
# TODO: need to make trt_compilation optional
# for sema_type in sema_types_to_check:
# assert semaphores_spec[sema_type] >= 0
semaphores_spec_ = semaphores_spec
except KeyError as e:
raise ValueError(
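
A sketch of a semaphore spec that includes the new type; the counts are illustrative, and it assumes configure_semaphores is importable from zetta_utils.mazepa.semaphores and accepts the spec as semaphores_spec:

from zetta_utils.mazepa import semaphore
from zetta_utils.mazepa.semaphores import configure_semaphores

# Illustrative counts; a single "trt_compilation" lease keeps TensorRT builds serialized.
spec = {"read": 4, "write": 4, "cuda": 1, "cpu": 8, "trt_compilation": 1}
with configure_semaphores(semaphores_spec=spec):
    with semaphore("trt_compilation"):
        pass  # at most one worker compiles a TensorRT engine at a time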