Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 78 additions & 10 deletions py/torch_tensorrt/_compile.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import collections.abc
import inspect
import logging
import platform
import warnings
Expand Down Expand Up @@ -561,21 +562,35 @@ def load_cross_compiled_exported_program(file_path: str = "") -> Any:
return dynamo_load_cross_compiled_exported_program(file_path)


def load(file_path: str = "") -> Any:
def load(
file_path: str = "", extra_files: Optional[dict[str, Any]] = None, **kwargs: Any
) -> Any:
"""
Load either a Torchscript model or ExportedProgram.

    Loads a TorchScript or ExportedProgram file from disk. The file type is detected automatically by attempting each loader in turn (try/except).

Arguments:
file_path (str): Path to file on the disk
extra_files (dict[str, Any]): Extra files to load with the model

Example:
# Load with extra files.
extra_files = {"foo.txt": ""} # values will be replaced with serialized data
    ep = torch_tensorrt.load("exported_program.pt2", extra_files=extra_files)
print(extra_files["foo.txt"])

Raises:
ValueError: If there is no file or the file is not either a TorchScript file or ExportedProgram file
"""
try:
logger.debug(f"Loading the provided file {file_path} using torch.jit.load()")
ts_module = torch.jit.load(file_path)
ts_module = function_overload_with_kwargs(
torch.jit.load,
file_path,
_extra_files=extra_files,
**kwargs,
)
return ts_module
except Exception:
logger.info(
Expand All @@ -586,7 +601,12 @@ def load(file_path: str = "") -> Any:

try:
logger.debug(f"Loading the provided file {file_path} using torch.export.load()")
exp_program = torch.export.load(file_path)
exp_program = function_overload_with_kwargs(
torch.export.load,
file_path,
extra_files=extra_files,
**kwargs,
)
return exp_program
except Exception:
logger.info(
Expand All @@ -602,6 +622,7 @@ def save(
module: Any,
file_path: str = "",
*,
extra_files: Optional[dict[str, str]] = None,
output_format: str = "exported_program",
inputs: Optional[Sequence[torch.Tensor]] = None,
arg_inputs: Optional[Sequence[torch.Tensor]] = None,
Expand All @@ -615,6 +636,8 @@ def save(

Arguments:
module (Optional(torch.jit.ScriptModule | torch.export.ExportedProgram | torch.fx.GraphModule | CudaGraphsTorchTensorRTModule)): Compiled Torch-TensorRT module
file_path (str): Path to file on the disk
extra_files (Optional[Dict[str, Any]]): Map from filename to contents which will be stored as part of saved file.
inputs (torch.Tensor): Torch input tensors
arg_inputs (Tuple[Any, ...]): Same as inputs. Alias for better understanding with kwarg_inputs.
kwarg_inputs (dict[Any, ...]): Optional, kwarg inputs to the module forward function.
Expand Down Expand Up @@ -670,7 +693,13 @@ def save(
logger.warning(
"Provided model is a torch.jit.ScriptModule, inputs or arg_inputs is not necessary during save."
)
torch.jit.save(module, file_path)
function_overload_with_kwargs(
torch.jit.save,
module,
file_path,
_extra_files=extra_files,
**kwargs,
)
elif module_type == _ModuleType.ep:
if output_format == "torchscript":
raise ValueError(
Expand All @@ -682,7 +711,14 @@ def save(
"Provided model is a torch.export.ExportedProgram, inputs or arg_inputs is not necessary during save, it uses the inputs or arg_inputs provided during export and compile"
)
if output_format == "exported_program":
torch.export.save(module, file_path, pickle_protocol=pickle_protocol)
function_overload_with_kwargs(
torch.export.save,
module,
file_path,
pickle_protocol=pickle_protocol,
extra_files=extra_files,
**kwargs,
)
elif output_format == "aot_inductor":
inductor_configs = {}
if "inductor_configs" in kwargs:
Expand All @@ -703,7 +739,13 @@ def save(
module_ts = torch.jit.trace(
module, arg_inputs, example_kwarg_inputs=kwarg_inputs
)
torch.jit.save(module_ts, file_path)
function_overload_with_kwargs(
torch.jit.save,
module_ts,
file_path,
_extra_files=extra_files,
**kwargs,
)
else:
if not retrace:
from torch_tensorrt.dynamo._exporter import export
Expand All @@ -714,8 +756,13 @@ def save(
)
exp_program = export(module)
if output_format == "exported_program":
torch.export.save(
exp_program, file_path, pickle_protocol=pickle_protocol
function_overload_with_kwargs(
torch.export.save,
exp_program,
file_path,
pickle_protocol=pickle_protocol,
extra_files=extra_files,
**kwargs,
)
elif output_format == "aot_inductor":
inductor_configs = {}
Expand Down Expand Up @@ -744,8 +791,13 @@ def save(
)

if output_format == "exported_program":
torch.export.save(
exp_program, file_path, pickle_protocol=pickle_protocol
function_overload_with_kwargs(
torch.export.save,
exp_program,
file_path,
pickle_protocol=pickle_protocol,
extra_files=extra_files,
**kwargs,
)
elif output_format == "aot_inductor":
inductor_configs = {}
Expand All @@ -761,3 +813,19 @@ def save(
raise RuntimeError(
"Attempted to serialize an exported program with an unsupported format. Exported programs support exported_program and aot_inductor"
)


def function_overload_with_kwargs(
    fn: Callable[..., Any], *args: Any, **kwargs: Any
) -> Any:
    """Call ``fn`` with ``args`` and only those keyword arguments it accepts.

    Inspects ``fn``'s signature and forwards keyword arguments whose names
    match declared parameters; any remaining keywords are dropped with a
    warning. This lets callers such as ``save``/``load`` pass a superset of
    options without raising a ``TypeError`` on the underlying torch API.

    Arguments:
        fn (Callable[..., Any]): Target function to invoke
        *args (Any): Positional arguments, forwarded verbatim
        **kwargs (Any): Candidate keyword arguments, filtered by signature

    Returns:
        Any: Whatever ``fn`` returns
    """
    fn_signature = inspect.signature(fn).parameters

    # If fn itself declares a **kwargs catch-all, filtering by parameter name
    # would wrongly drop valid keywords -- forward everything unchanged.
    if any(
        param.kind is inspect.Parameter.VAR_KEYWORD
        for param in fn_signature.values()
    ):
        return fn(*args, **kwargs)

    fn_kwargs = {}
    for k, v in kwargs.items():
        if k in fn_signature:
            fn_kwargs[k] = v
        else:
            logger.warning(
                f"Keyword argument {k} is not a valid argument for {fn.__name__}"
            )

    return fn(*args, **fn_kwargs)
59 changes: 59 additions & 0 deletions tests/py/dynamo/models/test_export_serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -728,3 +728,62 @@ def forward(self, x):
cos_sim > COSINE_THRESHOLD,
msg=f"test_save_load_ts TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
)


@pytest.mark.unit
def test_save_load_extra_files(ir, tmpdir):
    """
    Tests the save/load API with extra_files on the exported_program format
    (model compiled using the dynamo workflow): the extra file contents must
    round-trip through save/load, and the deserialized module's outputs must
    match those of the compiled module.
    """

    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 16, 3, stride=1, bias=True)
            self.relu = torch.nn.ReLU()

        def forward(self, x):
            conv = self.conv(x)
            relu = self.relu(conv)
            mul = relu * 0.5
            return mul

    ep_path = os.path.join(tmpdir, "trt.er")
    model = MyModule().eval().cuda()
    input = torch.randn((1, 3, 224, 224)).to("cuda")

    trt_gm = torchtrt.compile(
        model,
        ir=ir,
        inputs=[input],
        min_block_size=1,
        cache_built_engines=False,
        reuse_cached_engines=False,
    )
    assertions.assertTrue(
        isinstance(trt_gm, torch.fx.GraphModule),
        msg="test_save_load_extra_files output type does not match with torch.fx.GraphModule",
    )
    outputs_trt = trt_gm(input)
    # Save as an exported_program, attaching a user-provided extra file.
    torchtrt.save(
        trt_gm,
        ep_path,
        output_format="exported_program",
        inputs=[input],
        extra_files={"metadata": "Saving with extra files"},
    )

    # Values in the dict are replaced in-place with the deserialized contents.
    loaded_extra_files = {"metadata": None}
    trt_ep_module = torchtrt.load(ep_path, extra_files=loaded_extra_files)
    outputs_trt_deser = trt_ep_module.module()(input)

    cos_sim = cosine_similarity(outputs_trt, outputs_trt_deser)
    assertions.assertTrue(
        loaded_extra_files["metadata"] == "Saving with extra files",
        msg="Extra files not saved and loaded correctly",
    )
    assertions.assertTrue(
        cos_sim > COSINE_THRESHOLD,
        msg=f"test_save_load_extra_files TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
    )
Loading