diff --git a/py/torch_tensorrt/_compile.py b/py/torch_tensorrt/_compile.py
index 623bad7b9f..bd725b46cf 100644
--- a/py/torch_tensorrt/_compile.py
+++ b/py/torch_tensorrt/_compile.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import collections.abc
+import inspect
 import logging
 import platform
 import warnings
@@ -561,7 +562,9 @@ def load_cross_compiled_exported_program(file_path: str = "") -> Any:
     return dynamo_load_cross_compiled_exported_program(file_path)
 
 
-def load(file_path: str = "") -> Any:
+def load(
+    file_path: str = "", extra_files: Optional[dict[str, Any]] = None, **kwargs: Any
+) -> Any:
     """
     Load either a Torchscript model or ExportedProgram.
 
@@ -569,13 +572,25 @@
 
     Arguments:
         file_path (str): Path to file on the disk
+        extra_files (dict[str, Any]): Extra files to load with the model
+
+    Example:
+        # Load with extra files.
+        extra_files = {"foo.txt": ""}  # values will be replaced with serialized data
+        ep = torch_tensorrt.load("exported_program.pt2", extra_files=extra_files)
+        print(extra_files["foo.txt"])
 
     Raises:
         ValueError: If there is no file or the file is not either a TorchScript file or ExportedProgram file
     """
     try:
         logger.debug(f"Loading the provided file {file_path} using torch.jit.load()")
-        ts_module = torch.jit.load(file_path)
+        ts_module = function_overload_with_kwargs(
+            torch.jit.load,
+            file_path,
+            _extra_files=extra_files,
+            **kwargs,
+        )
         return ts_module
     except Exception:
         logger.info(
@@ -586,7 +601,12 @@
 
     try:
         logger.debug(f"Loading the provided file {file_path} using torch.export.load()")
-        exp_program = torch.export.load(file_path)
+        exp_program = function_overload_with_kwargs(
+            torch.export.load,
+            file_path,
+            extra_files=extra_files,
+            **kwargs,
+        )
         return exp_program
     except Exception:
         logger.info(
@@ -602,6 +622,7 @@ def save(
     module: Any,
     file_path: str = "",
     *,
+    extra_files: Optional[dict[str, str]] = None,
     output_format: str = "exported_program",
     inputs: Optional[Sequence[torch.Tensor]] = None,
     arg_inputs: Optional[Sequence[torch.Tensor]] = None,
@@ -615,6 +636,8 @@ def save(
 
     Arguments:
         module (Optional(torch.jit.ScriptModule | torch.export.ExportedProgram | torch.fx.GraphModule | CudaGraphsTorchTensorRTModule)): Compiled Torch-TensorRT module
+        file_path (str): Path to file on the disk
+        extra_files (Optional[Dict[str, Any]]): Map from filename to contents which will be stored as part of saved file.
         inputs (torch.Tensor): Torch input tensors
         arg_inputs (Tuple[Any, ...]): Same as inputs. Alias for better understanding with kwarg_inputs.
         kwarg_inputs (dict[Any, ...]): Optional, kwarg inputs to the module forward function.
@@ -670,7 +693,13 @@
         logger.warning(
             "Provided model is a torch.jit.ScriptModule, inputs or arg_inputs is not necessary during save."
         )
-        torch.jit.save(module, file_path)
+        function_overload_with_kwargs(
+            torch.jit.save,
+            module,
+            file_path,
+            _extra_files=extra_files,
+            **kwargs,
+        )
     elif module_type == _ModuleType.ep:
         if output_format == "torchscript":
             raise ValueError(
@@ -682,7 +711,14 @@
                 "Provided model is a torch.export.ExportedProgram, inputs or arg_inputs is not necessary during save, it uses the inputs or arg_inputs provided during export and compile"
             )
         if output_format == "exported_program":
-            torch.export.save(module, file_path, pickle_protocol=pickle_protocol)
+            function_overload_with_kwargs(
+                torch.export.save,
+                module,
+                file_path,
+                pickle_protocol=pickle_protocol,
+                extra_files=extra_files,
+                **kwargs,
+            )
         elif output_format == "aot_inductor":
             inductor_configs = {}
             if "inductor_configs" in kwargs:
@@ -703,7 +739,13 @@
             module_ts = torch.jit.trace(
                 module, arg_inputs, example_kwarg_inputs=kwarg_inputs
             )
-            torch.jit.save(module_ts, file_path)
+            function_overload_with_kwargs(
+                torch.jit.save,
+                module_ts,
+                file_path,
+                _extra_files=extra_files,
+                **kwargs,
+            )
         else:
             if not retrace:
                 from torch_tensorrt.dynamo._exporter import export
@@ -714,8 +756,13 @@
                 )
                 exp_program = export(module)
                 if output_format == "exported_program":
-                    torch.export.save(
-                        exp_program, file_path, pickle_protocol=pickle_protocol
+                    function_overload_with_kwargs(
+                        torch.export.save,
+                        exp_program,
+                        file_path,
+                        pickle_protocol=pickle_protocol,
+                        extra_files=extra_files,
+                        **kwargs,
                     )
                 elif output_format == "aot_inductor":
                     inductor_configs = {}
@@ -744,8 +791,13 @@
                 )
 
                 if output_format == "exported_program":
-                    torch.export.save(
-                        exp_program, file_path, pickle_protocol=pickle_protocol
+                    function_overload_with_kwargs(
+                        torch.export.save,
+                        exp_program,
+                        file_path,
+                        pickle_protocol=pickle_protocol,
+                        extra_files=extra_files,
+                        **kwargs,
                     )
                 elif output_format == "aot_inductor":
                     inductor_configs = {}
@@ -761,3 +813,19 @@
         raise RuntimeError(
             "Attempted to serialize an exported program with an unsupported format. Exported programs support exported_program and aot_inductor"
         )
+
+
+def function_overload_with_kwargs(
+    fn: Callable[..., Any], *args: Any, **kwargs: Any
+) -> Any:
+    fn_signature = inspect.signature(fn).parameters
+    fn_kwargs = {}
+    for k, v in kwargs.items():
+        if k in fn_signature:
+            fn_kwargs[k] = v
+        else:
+            logger.warning(
+                f"Keyword argument {k} is not a valid argument for {fn.__name__}"
+            )
+
+    return fn(*args, **fn_kwargs)
diff --git a/tests/py/dynamo/models/test_export_serde.py b/tests/py/dynamo/models/test_export_serde.py
index 3209177120..bf619cc51b 100644
--- a/tests/py/dynamo/models/test_export_serde.py
+++ b/tests/py/dynamo/models/test_export_serde.py
@@ -728,3 +728,62 @@ def forward(self, x):
         cos_sim > COSINE_THRESHOLD,
         msg=f"test_save_load_ts TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
     )
+
+
+@pytest.mark.unit
+def test_save_load_extra_files(ir, tmpdir):
+    """
+    This tests save/load API with extra_files on exported_program format (model still compiled using dynamo workflow)
+    """
+
+    class MyModule(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv = torch.nn.Conv2d(3, 16, 3, stride=1, bias=True)
+            self.relu = torch.nn.ReLU()
+
+        def forward(self, x):
+            conv = self.conv(x)
+            relu = self.relu(conv)
+            mul = relu * 0.5
+            return mul
+
+    ep_path = os.path.join(tmpdir, "trt.er")
+    model = MyModule().eval().cuda()
+    input = torch.randn((1, 3, 224, 224)).to("cuda")
+
+    trt_gm = torchtrt.compile(
+        model,
+        ir=ir,
+        inputs=[input],
+        min_block_size=1,
+        cache_built_engines=False,
+        reuse_cached_engines=False,
+    )
+    assertions.assertTrue(
+        isinstance(trt_gm, torch.fx.GraphModule),
+        msg=f"test_save_load_ts output type does not match with torch.fx.GraphModule",
+    )
+    outputs_trt = trt_gm(input)
+    # Save it as exported_program representation
+    torchtrt.save(
+        trt_gm,
+        ep_path,
+        output_format="exported_program",
+        inputs=[input],
+        extra_files={"metadata": "Saving with extra files"},
+    )
+
+    loaded_extra_files = {"metadata": None}
+    trt_ep_module = torchtrt.load(ep_path, extra_files=loaded_extra_files)
+    outputs_trt_deser = trt_ep_module.module()(input)
+
+    cos_sim = cosine_similarity(outputs_trt, outputs_trt_deser)
+    assertions.assertTrue(
+        loaded_extra_files["metadata"] == "Saving with extra files",
+        msg="Extra files not saved and loaded correctly",
+    )
+    assertions.assertTrue(
+        cos_sim > COSINE_THRESHOLD,
+        msg=f"test_save_load_ts TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
+    )