From bcc4f7cfe598c3cdf03461053e7ac9d3a0bff4ba Mon Sep 17 00:00:00 2001 From: cehongwang Date: Sun, 4 Jan 2026 23:51:36 +0000 Subject: [PATCH 1/3] Squashed and fixed torchscript CI --- core/runtime/TRTEngine.cpp | 8 + core/runtime/TRTEngine.h | 3 + core/runtime/execute_engine.cpp | 4 +- core/runtime/register_jit_hooks.cpp | 2 + .../dynamo/pre_allocated_output_example.py | 1 - py/torch_tensorrt/dynamo/_compiler.py | 14 +- .../runtime/_PythonTorchTensorRTModule.py | 49 ++++-- .../dynamo/runtime/_TorchTensorRTModule.py | 13 ++ .../runtime/test_pre_allocated_outputs.py | 38 +++++ .../perf/graph_break_overhead/graph_break.py | 150 ++++++++++++++++++ 10 files changed, 268 insertions(+), 14 deletions(-) create mode 100644 tools/perf/graph_break_overhead/graph_break.py diff --git a/core/runtime/TRTEngine.cpp b/core/runtime/TRTEngine.cpp index 52a9b47c12..a86ca1dbf6 100644 --- a/core/runtime/TRTEngine.cpp +++ b/core/runtime/TRTEngine.cpp @@ -281,6 +281,14 @@ void TRTEngine::enable_profiling() { exec_ctx->setProfiler(trt_engine_profiler.get()); } +void TRTEngine::set_output_tensors_as_unowned(bool enable) { + this->output_tensors_are_unowned = enable; +} + +bool TRTEngine::are_output_tensors_unowned() { + return this->output_tensors_are_unowned; +} + void TRTEngine::set_profile_format(std::string format) { if (format == "trex") { this->trt_engine_profiler->set_profile_format(TraceFormat::kTREX); diff --git a/core/runtime/TRTEngine.h b/core/runtime/TRTEngine.h index 15d723ce4e..7e0de52126 100644 --- a/core/runtime/TRTEngine.h +++ b/core/runtime/TRTEngine.h @@ -103,6 +103,7 @@ struct TRTEngine : torch::CustomClassHolder { std::shared_ptr<nvinfer1::ICudaEngine> cuda_engine; std::shared_ptr<nvinfer1::IExecutionContext> exec_ctx; std::pair<uint64_t, uint64_t> num_io; + bool output_tensors_are_unowned = false; std::string name; RTDevice device_info; @@ -159,6 +160,8 @@ struct TRTEngine : torch::CustomClassHolder { int64_t get_automatic_device_memory_budget(); std::vector<at::Tensor> infer_outputs(std::vector<std::vector<int64_t>> input_shapes); void set_pre_allocated_outputs(bool enable); + void set_output_tensors_as_unowned(bool enable); + bool are_output_tensors_unowned(); TorchTRTRuntimeStates runtime_states; friend std::ostream& operator<<(std::ostream& os, const TRTEngine& engine); static const char BINDING_DELIM = '%'; diff --git a/core/runtime/execute_engine.cpp b/core/runtime/execute_engine.cpp index 64b111750f..bfc2e73670 100644 --- a/core/runtime/execute_engine.cpp +++ b/core/runtime/execute_engine.cpp @@ -321,7 +321,9 @@ std::vector<at::Tensor> execute_engine(std::vector<at::IValue> inputs, c10::intr } // End engine exeuction (resets to caller stream) // Create output buffer for next execution of graph or trt context.
- if (compiled_engine->use_pre_allocated_outputs) { + if (compiled_engine->use_pre_allocated_outputs && + (compiled_engine->pre_allocated_outputs.size() == 0 || compiled_engine->output_tensors_are_unowned || + shape_changed)) { compiled_engine->pre_allocated_outputs = create_output_tensors(compiled_engine); } diff --git a/core/runtime/register_jit_hooks.cpp b/core/runtime/register_jit_hooks.cpp index 173ff8c35f..49cd12f86a 100644 --- a/core/runtime/register_jit_hooks.cpp +++ b/core/runtime/register_jit_hooks.cpp @@ -90,6 +90,8 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion = .def("get_engine_layer_info", &TRTEngine::get_engine_layer_info) .def("infer_outputs", &TRTEngine::infer_outputs) .def("reset_captured_graph", &TRTEngine::reset_captured_graph) + .def("set_output_tensors_as_unowned", &TRTEngine::set_output_tensors_as_unowned) + .def("are_output_tensors_unowned", &TRTEngine::are_output_tensors_unowned) .def_readwrite("use_pre_allocated_outputs", &TRTEngine::use_pre_allocated_outputs) .def_readwrite("use_output_allocator_outputs", &TRTEngine::use_output_allocator_outputs) .def_property( diff --git a/examples/dynamo/pre_allocated_output_example.py b/examples/dynamo/pre_allocated_output_example.py index 2ad1b8f514..f96d796cc9 100644 --- a/examples/dynamo/pre_allocated_output_example.py +++ b/examples/dynamo/pre_allocated_output_example.py @@ -79,7 +79,6 @@ def test_module_perf(model, *input): optimized_model = torch_tensorrt.compile( model, ir="dynamo", - enabled_precisions={torch.half}, inputs=inputs, ) diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 120855c5bb..67be775d3d 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -557,7 +557,7 @@ def compile( stacklevel=2, ) - if kwargs.get("use_explicit_typing", False) == False: + if not kwargs.get("use_explicit_typing", False): warnings.warn( "`use_explicit_typing` is deprecated. 
This setting will be removed and you should enable autocast instead.", DeprecationWarning, @@ -1070,6 +1072,8 @@ def preserve_module_specs( ) as f: f.write(trt_module.get_layer_info()) + # Only set the output-tensors-as-unowned flag for the last TRT module, since the user has access to its output tensors + # Parse the graph I/O and store it in dryrun tracker parse_graph_io(gm, dryrun_tracker) @@ -1077,7 +1079,15 @@ def preserve_module_specs( for name, trt_module in trt_modules.items(): setattr(partitioned_module, name, trt_module) if settings.lazy_engine_init and not settings.enable_cross_compile_for_windows: - getattr(partitioned_module, name).setup_engine() + trt_module = getattr(partitioned_module, name) + trt_module.setup_engine() + + output_node = list(partitioned_module.graph.nodes)[-1] + for arg in output_node.args: + target = arg[0].target + if "_run_on_acc" not in str(target): + continue + getattr(partitioned_module, target).set_output_tensors_as_unowned(True) # Reset settings object to user specification after fallback to global partitioning mode if fast_partitioner_failed: diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py index a946f38761..b04af70267 100644 --- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py @@ -218,10 +218,27 @@ def __init__( self.requires_output_allocator = requires_output_allocator self.output_allocator: Optional[DynamicOutputAllocator] = None self.use_output_allocator_outputs = False - + self.device = torch.cuda.current_device() + self.cudagraphs_enabled = torch_tensorrt.runtime.get_cudagraphs_mode() + # If the output tensor is not owned by the engine (output_tensors_are_unowned=True), we need to create a new output tensor in each forward pass + self.output_tensors_are_unowned = False if self.serialized_engine is not None and not self.settings.lazy_engine_init: self.setup_engine() + def set_output_tensors_as_unowned(self, enabled: bool) -> None: + """ + Set whether the output tensors of this engine are solely owned by the Torch-TensorRT runtime or may be shared with the user. + If the tensors are not owned by the runtime, they must be recreated on every forward call, which may have performance implications. + Typically only the final engine in a graph requires unowned output tensors, and there are performance gains when intermediate engines manage their own standing memory. + Therefore this should only be set to True for the final module in a graph and left False for intermediate modules. + + Args: + enabled: bool + Whether to treat this engine's output tensors as unowned by the runtime.
+ + """ + self.output_tensors_are_unowned = enabled + def get_streamable_device_memory_budget(self) -> Any: return self.engine.streamable_weights_size @@ -298,6 +315,11 @@ def setup_engine(self) -> None: if torch_tensorrt.runtime.get_cudagraphs_mode(): self.cudagraph = torch.cuda.CUDAGraph() + self.is_shape_inference_io = { + input_name: self.engine.is_shape_inference_io(input_name) + for input_name in self.input_names + } + def _check_initialized(self) -> None: if not self.initialized: raise RuntimeError("PythonTorchTensorRTModule is not initialized.") @@ -383,7 +405,7 @@ def setup_input_tensors( # For shape tensors, we use CPU pointers and for data tensors, we use GPU pointers # as per TensorRT requirements - if self.engine.is_shape_inference_io(input_name): + if self.is_shape_inference_io[input_name]: # Shape tensor inputs are casted to int64 explicitly # Currently Torch CPU pointers are not working; numpy pointers are used instead # to refer to underlying memory @@ -411,7 +433,7 @@ def create_output_tensors(self) -> List[torch.Tensor]: output = torch.empty( size=self.output_shapes[o], dtype=self.output_dtypes[o], - device=torch.cuda.current_device(), + device=self.device, ) outputs.append(output) return outputs @@ -548,7 +570,11 @@ def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]: self._caller_stream.wait_stream(self._engine_stream) - if self.use_pre_allocated_outputs: + if self.use_pre_allocated_outputs and ( + self.output_tensors_are_unowned + or not self.pre_allocated_outputs + or shape_changed + ): self.pre_allocated_outputs = self.create_output_tensors() if self.cudagraphs_enabled: @@ -751,13 +777,13 @@ def validate_input_shapes(self, inputs: Sequence[torch.Tensor]) -> bool: # Representation of input shapes to a given model # Shapes are concatenated as so: # x: (3, 4), y: (4, 5) --> Key: (3,4)(4,5) - tensor_inputs = [] - for t in inputs: - if not isinstance(t, torch.Tensor): - return True - tensor_inputs.append(t) + if not all(isinstance(t, torch.Tensor) for t in inputs): + return True + new_shape_key = "".join( - str(tuple(t.shape)).replace(" ", "") for t in tensor_inputs + str(tuple(t.shape)).replace(" ", "") + for t in inputs + if isinstance(t, torch.Tensor) ) # If the new shape key differs from the existing one, @@ -768,3 +794,6 @@ def validate_input_shapes(self, inputs: Sequence[torch.Tensor]) -> bool: return True return False + + def are_output_tensors_unowned(self) -> bool: + return self.output_tensors_are_unowned diff --git a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py index 95f1581881..23c372167d 100644 --- a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py @@ -156,6 +156,11 @@ def _pack_engine_info(self) -> List[str | bytes]: metadata = { "settings": self.settings, "weight_name_map": self.weight_name_map, + "output_tensors_are_unowned": ( + False + if self.engine is None + else self.engine.are_output_tensors_unowned() + ), } target_platform = ( Platform.current_platform() @@ -284,6 +289,8 @@ def set_extra_state(self, state: SerializedTorchTensorRTModuleFmt) -> None: metadata = TorchTensorRTModule.decode_metadata(serialized_metadata) self.settings = metadata["settings"] self.weight_name_map = metadata["weight_name_map"] + self.output_tensors_are_unowned = metadata["output_tensors_are_unowned"] + self.engine.set_output_tensors_as_unowned(self.output_tensors_are_unowned) else: self.engine = None @@ -355,6 
+362,12 @@ def enable_profiling( self.engine.enable_profiling() self.engine.set_profile_format(profile_format) + def set_output_tensors_as_unowned(self, enabled: bool) -> None: + self.engine.set_output_tensors_as_unowned(enabled) + + def are_output_tensors_unowned(self) -> bool: + return self.engine.are_output_tensors_unowned() # type: ignore[no-any-return] + def disable_profiling(self) -> None: """Disable the profiler""" if self.engine is None: diff --git a/tests/py/dynamo/runtime/test_pre_allocated_outputs.py b/tests/py/dynamo/runtime/test_pre_allocated_outputs.py index b8c7b61fb3..35dab61161 100644 --- a/tests/py/dynamo/runtime/test_pre_allocated_outputs.py +++ b/tests/py/dynamo/runtime/test_pre_allocated_outputs.py @@ -125,6 +125,44 @@ def forward(self, x): ) torch._dynamo.reset() + def test_pre_allocated_outputs_unowned_outputs(self): + class SampleModel(torch.nn.Module): + def forward(self, x): + return torch.softmax(x * 7 + 2, dim=0) + + model = SampleModel().eval().cuda() + inputs = [torch.randn(*INPUT_SIZE).cuda() for _ in range(TRIALS)] + fx_graph = torch.fx.symbolic_trace(model) + + # Validate that the results between Torch and Torch-TRT are similar + optimized_model = torchtrt.compile( + fx_graph, + "dynamo", + inputs[0], + min_block_size=1, + pass_through_build_failures=True, + use_python_runtime=True, + torch_executed_ops={torch.ops.aten.add.Tensor}, + ) + + with torchtrt.runtime.enable_pre_allocated_outputs(optimized_model): + optimized_model(inputs[0]) + output_tensors = [ + trt_mod.pre_allocated_outputs + for name, trt_mod in optimized_model.named_children() + if "_run_on_acc" in name + ] + optimized_model(inputs[0]) + new_output_tensors = [ + trt_mod.pre_allocated_outputs + for name, trt_mod in optimized_model.named_children() + if "_run_on_acc" in name + ] + self.assertTrue(output_tensors[0] is new_output_tensors[0]) + self.assertTrue(output_tensors[1] is not new_output_tensors[1]) + + torch._dynamo.reset() + if __name__ == "__main__": run_tests() diff --git a/tools/perf/graph_break_overhead/graph_break.py b/tools/perf/graph_break_overhead/graph_break.py new file mode 100644 index 0000000000..47d467d35b --- /dev/null +++ b/tools/perf/graph_break_overhead/graph_break.py @@ -0,0 +1,150 @@ +import time + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch_tensorrt as torchtrt +import torchvision +from pyinstrument import Profiler +from torch_tensorrt.dynamo.utils import get_model_device + +torch.manual_seed(0) +torch.cuda.manual_seed_all(0) +import argparse + + +def benchmark_model(model, input, label, profile=False): + if profile: + profiler = Profiler(interval=0.01) + profiler.start() + start_time = time.time() + for _ in range(1000): + model_outputs = model(*input) + end_time = time.time() + print(f"{label} 1000 runs: {end_time - start_time:.4f} seconds") + if profile: + profiler.stop() + profiler.write_html( + f"/home/other/{label.replace(' ', '_')}.html", timeline=False, show_all=True + ) + + +def main(args): + profile = args.profile + use_python_runtime = args.use_python_runtime + model_name = args.model + + with torchtrt.dynamo.Debugger(log_level="debug", engine_builder_monitor=False): + + model = ( + torchvision.models.__dict__[model_name](pretrained=True).eval().to("cuda") + ) + input = [torch.randn((1, 3, 224, 224)).to("cuda")] + + BATCH = torch.export.Dim("BATCH", min=1, max=16) + exp_program = torch.export.export(model, tuple(input), strict=True) + trt_mod2 = trt_gm = torchtrt.dynamo.compile( + exp_program, + tuple(input), + 
use_python_runtime=use_python_runtime, + enabled_precisions={torch.float}, + min_block_size=1, + immutable_weights=False, + reuse_cached_engines=False, + ) + + trt_mod1 = trt_gm = torchtrt.dynamo.compile( + exp_program, + tuple(input), + use_python_runtime=use_python_runtime, + enabled_precisions={torch.float}, + min_block_size=1, + immutable_weights=False, + torch_executed_ops={torch.ops.aten.relu.default}, + reuse_cached_engines=False, + ) + + # AOTI + if not use_python_runtime: + torchtrt.save( + trt_mod1, + "/home/other/aoti.pt2", + output_format="aot_inductor", + inputs=input, + retrace=True, + ) + aoti_model_gb = torch._inductor.aoti_load_package("/home/other/aoti.pt2") + torchtrt.save( + trt_mod2, + "/home/other/aoti_no_gb.pt2", + output_format="aot_inductor", + inputs=input, + retrace=True, + ) + aoti_model_no_gb = torch._inductor.aoti_load_package( + "/home/other/aoti_no_gb.pt2" + ) + + # Warmup runs to avoid measuring first-run overheads + for _ in range(100): + trt_mod2(*input) + model(*input) + if not use_python_runtime: + aoti_model_gb(*input) + aoti_model_no_gb(*input) + + time.sleep(1) + benchmark_model(trt_mod1, input, "trt_mod1 (with graph break)", profile=profile) + benchmark_model(trt_mod2, input, "trt_mod2 (without graph break)", profile=profile) + if not use_python_runtime: + benchmark_model(aoti_model_gb, input, "aoti_model_gb", profile=profile) + benchmark_model(aoti_model_no_gb, input, "aoti_model_no_gb", profile=profile) + + out1 = trt_mod1(*input) + out2 = trt_mod2(*input) + if not use_python_runtime: + out3 = aoti_model_gb(*input) + out4 = aoti_model_no_gb(*input) + + def _to_tuple(x): + if isinstance(x, (tuple, list)): + return tuple(x) + return (x,) + + outs1 = _to_tuple(out1) + outs2 = _to_tuple(out2) + if not use_python_runtime: + outs3 = _to_tuple(out3) + outs4 = _to_tuple(out4) + + def compare_outputs(a, b, name1="A", name2="B"): + if len(a) != len(b): + print(f"Number of outputs differ: {len(a)} vs {len(b)}") + return False + all_equal = True + for i, (x, y) in enumerate(zip(a, b)): + if not torch.allclose(x, y, atol=1e-3, rtol=1e-3): + print(f"Output {i} differs between {name1} and {name2}") + print(f"max diff: {torch.max(torch.abs(x - y))}") + print(f"Mean diff: {torch.mean(torch.abs(x - y))}") + all_equal = False + if all_equal: + print(f"All outputs match between {name1} and {name2}") + return all_equal + + compare_outputs(outs1, outs2, "trt_mod1", "trt_mod2") + if not use_python_runtime: + compare_outputs(outs1, outs3, "trt_mod1", "aoti_model_gb") + compare_outputs(outs1, outs4, "trt_mod1", "aoti_model_no_gb") + compare_outputs(outs2, outs3, "trt_mod2", "aoti_model") + + if __name__ == "__main__": + arg_parser = argparse.ArgumentParser() + arg_parser.add_argument("--profile", action="store_true") + arg_parser.add_argument("--use_python_runtime", action="store_true") + arg_parser.add_argument( + "--model", type=str, default="resnet18", choices=["resnet18", "resnet152"] + ) + args = arg_parser.parse_args() + main(args) From 99660e6b95c1ac2b93cf3bc3188f49b4d7035a23 Mon Sep 17 00:00:00 2001 From: cehongwang Date: Mon, 5 Jan 2026 08:40:01 +0000 Subject: [PATCH 2/3] Added explanation --- core/runtime/execute_engine.cpp | 3 ++- py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/core/runtime/execute_engine.cpp b/core/runtime/execute_engine.cpp index bfc2e73670..54e9701c9e 100644 --- a/core/runtime/execute_engine.cpp +++ b/core/runtime/execute_engine.cpp @@ -320,7 +320,8 @@
std::vector<at::Tensor> execute_engine(std::vector<at::IValue> inputs, c10::intr } } // End engine exeuction (resets to caller stream) - // Create output buffer for next execution of graph or trt context. + // When pre-allocated output mode is enabled, intermediate modules only allocate new output buffers on the first + // execution or when the input shape changes. if (compiled_engine->use_pre_allocated_outputs && (compiled_engine->pre_allocated_outputs.size() == 0 || compiled_engine->output_tensors_are_unowned || shape_changed)) { diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py index b04af70267..00b5224740 100644 --- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py @@ -570,6 +570,7 @@ def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]: self._caller_stream.wait_stream(self._engine_stream) + # When pre-allocated output mode is enabled, intermediate modules only allocate new output buffers on the first execution or when the input shape changes. if self.use_pre_allocated_outputs and ( self.output_tensors_are_unowned or not self.pre_allocated_outputs From 0bef55bc7fa99c23d07c9983dcc074d7605bcf7e Mon Sep 17 00:00:00 2001 From: Naren Dasan <1790613+narendasan@users.noreply.github.com> Date: Mon, 5 Jan 2026 18:05:35 -0700 Subject: [PATCH 3/3] tests: Adding additional test cases for the unowned tensor feature (#3993) Co-authored-by: cehongwang --- py/torch_tensorrt/dynamo/_compiler.py | 9 +- .../runtime/test_pre_allocated_outputs.py | 241 +++++++++++++++++- 2 files changed, 243 insertions(+), 7 deletions(-) diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 67be775d3d..cbde956a88 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -1084,10 +1084,11 @@ def preserve_module_specs( output_node = list(partitioned_module.graph.nodes)[-1] for arg in output_node.args: - target = arg[0].target - if "_run_on_acc" not in str(target): - continue - getattr(partitioned_module, target).set_output_tensors_as_unowned(True) + for output in arg: + target = output.target + if "_run_on_acc" not in str(target): + continue + getattr(partitioned_module, target).set_output_tensors_as_unowned(True) # Reset settings object to user specification after fallback to global partitioning mode if fast_partitioner_failed: diff --git a/tests/py/dynamo/runtime/test_pre_allocated_outputs.py b/tests/py/dynamo/runtime/test_pre_allocated_outputs.py index 35dab61161..a9f8cfbbe5 100644 --- a/tests/py/dynamo/runtime/test_pre_allocated_outputs.py +++ b/tests/py/dynamo/runtime/test_pre_allocated_outputs.py @@ -125,7 +125,7 @@ def forward(self, x): ) torch._dynamo.reset() - def test_pre_allocated_outputs_unowned_outputs(self): + def test_pre_allocated_outputs_unowned_outputs_py_api_check_no_realloc(self): class SampleModel(torch.nn.Module): def forward(self, x): return torch.softmax(x * 7 + 2, dim=0) @@ -146,21 +146,256 @@ def forward(self, x): ) with torchtrt.runtime.enable_pre_allocated_outputs(optimized_model): - optimized_model(inputs[0]) + _ = optimized_model(inputs[0]) output_tensors = [ trt_mod.pre_allocated_outputs for name, trt_mod in optimized_model.named_children() if "_run_on_acc" in
name ] + + # Run to run, output of intermediate engine is not reallocated self.assertTrue(output_tensors[0] is new_output_tensors[0]) + # Run to run, output of output engine is reallocated self.assertTrue(output_tensors[1] is not new_output_tensors[1]) + @parameterized.expand( + [ + ("python_runtime", True), + ("cpp_runtime", False), + ] + ) + def test_pre_allocated_outputs_unowned_outputs_api_check( + self, _, use_python_runtime + ): + class SampleModel(torch.nn.Module): + def forward(self, x): + return torch.softmax(x * 7 + 2, dim=0) + + model = SampleModel().eval().cuda() + inputs = [torch.randn(*INPUT_SIZE).cuda() for _ in range(TRIALS)] + fx_graph = torch.fx.symbolic_trace(model) + + # Validate that the results between Torch and Torch-TRT are similar + optimized_model = torchtrt.compile( + fx_graph, + "dynamo", + inputs[0], + min_block_size=1, + pass_through_build_failures=True, + use_python_runtime=use_python_runtime, + torch_executed_ops={torch.ops.aten.add.Tensor}, + ) + + with torchtrt.runtime.enable_pre_allocated_outputs(optimized_model): + _ = optimized_model(inputs[0]) + self.assertTrue( + all( + seen == expected + for seen, expected in zip( + [ + optimized_model._run_on_acc_0.are_output_tensors_unowned(), + optimized_model._run_on_acc_2.are_output_tensors_unowned(), + ], + [False, True], + ) + ) + ) + + @parameterized.expand( + [ + ("python_runtime", True), + ("cpp_runtime", False), + ] + ) + def test_pre_allocated_outputs_unowned_outputs(self, _, use_python_runtime): + class SampleModel(torch.nn.Module): + def forward(self, x): + return torch.softmax(x * 7 + 2, dim=0) + + model = SampleModel().eval().cuda() + inputs = [torch.randn(*INPUT_SIZE).cuda() for _ in range(TRIALS)] + fx_graph = torch.fx.symbolic_trace(model) + + # Validate that the results between Torch and Torch-TRT are similar + optimized_model = torchtrt.compile( + fx_graph, + "dynamo", + inputs[0], + min_block_size=1, + pass_through_build_failures=True, + use_python_runtime=use_python_runtime, + torch_executed_ops={torch.ops.aten.add.Tensor}, + ) + + torch_res = model(inputs[0]) + + with torchtrt.runtime.enable_pre_allocated_outputs(optimized_model): + res_1 = optimized_model(inputs[0]) + res_2 = optimized_model(inputs[0]) + + # Results are correct + torch.testing.assert_close( + torch_res, + res_1, + rtol=5e-03, + atol=5e-03, + equal_nan=True, + check_dtype=True, + ) + + # Results between runs are identical + torch.testing.assert_close( + res_1, + res_2, + rtol=5e-03, + atol=5e-03, + equal_nan=True, + check_dtype=True, + ) + + torch._dynamo.reset() + + def test_pre_allocated_outputs_unowned_outputs_multiple_outputs_py_api_check_no_realloc( + self, + ): + class SampleModel(torch.nn.Module): + def forward(self, x): + y = torch.ops.aten.mul(x, 7) + z = torch.ops.aten.add(y, 2) + a = torch.ops.aten.softmax(z, dim=0) + return y, z, a + + model = SampleModel().eval().cuda() + inputs = [torch.randn(*INPUT_SIZE).cuda() for _ in range(TRIALS)] + fx_graph = torch.fx.symbolic_trace(model) + + # Validate that the results between Torch and Torch-TRT are similar + optimized_model = torchtrt.compile( + fx_graph, + "dynamo", + inputs[0], + min_block_size=1, + pass_through_build_failures=True, + use_python_runtime=True, + torch_executed_ops={torch.ops.aten.add.Tensor}, + ) + + with torchtrt.runtime.enable_pre_allocated_outputs(optimized_model): + res1 = optimized_model(inputs[0]) + output_tensors = [ + [t.data_ptr() for t in trt_mod.pre_allocated_outputs] + for name, trt_mod in optimized_model.named_children() + if "_run_on_acc" 
in name + ] + + _ = optimized_model(inputs[0]) + new_output_tensors = [ + [t.data_ptr() for t in trt_mod.pre_allocated_outputs] + for name, trt_mod in optimized_model.named_children() + if "_run_on_acc" in name + ] + + # Run to run, output of intermediate engine is reallocated + self.assertTrue(output_tensors[0] != new_output_tensors[0]) + # Run to run, output of output engine is reallocated + self.assertTrue(output_tensors[1] != new_output_tensors[1]) + + @parameterized.expand( + [ + ("python_runtime", True), + ("cpp_runtime", False), + ] + ) + def test_pre_allocated_outputs_unowned_outputs_multiple_outputs_api_check( + self, _, use_python_runtime + ): + class SampleModel(torch.nn.Module): + def forward(self, x): + y = torch.ops.aten.mul(x, 7) + z = torch.ops.aten.add(y, 2) + a = torch.ops.aten.softmax(z, dim=0) + return y, z, a + + model = SampleModel().eval().cuda() + inputs = [torch.randn(*INPUT_SIZE).cuda() for _ in range(TRIALS)] + fx_graph = torch.fx.symbolic_trace(model) + + # Validate that the results between Torch and Torch-TRT are similar + optimized_model = torchtrt.compile( + fx_graph, + "dynamo", + inputs[0], + min_block_size=1, + pass_through_build_failures=True, + use_python_runtime=use_python_runtime, + torch_executed_ops={torch.ops.aten.add.Tensor}, + ) + + with torchtrt.runtime.enable_pre_allocated_outputs(optimized_model): + _ = optimized_model(inputs[0]) + self.assertTrue( + all( + seen == expected + for seen, expected in zip( + [ + optimized_model._run_on_acc_0.are_output_tensors_unowned(), + optimized_model._run_on_acc_2.are_output_tensors_unowned(), + ], + [True, True], + ) + ) + ) + + @parameterized.expand( + [ + ("python_runtime", True), + ("cpp_runtime", False), + ] + ) + def test_pre_allocated_outputs_unowned_outputs_multi_outputs( + self, _, use_python_runtime + ): + class SampleModel(torch.nn.Module): + def forward(self, x): + y = torch.ops.aten.mul(x, 7) + z = torch.ops.aten.add(y, 2) + a = torch.ops.aten.softmax(z, dim=0) + return y, z, a + + model = SampleModel().eval().cuda() + inputs = [torch.randn(*INPUT_SIZE).cuda() for _ in range(TRIALS)] + fx_graph = torch.fx.symbolic_trace(model) + + # Validate that the results between Torch and Torch-TRT are similar + optimized_model = torchtrt.compile( + fx_graph, + "dynamo", + inputs[0], + min_block_size=1, + pass_through_build_failures=True, + use_python_runtime=use_python_runtime, + torch_executed_ops={torch.ops.aten.add.Tensor}, + ) + + with torchtrt.runtime.enable_pre_allocated_outputs(optimized_model): + res_1 = optimized_model(inputs[0]) + res_2 = optimized_model(inputs[0]) + + torch.testing.assert_close( + res_1, + res_2, + rtol=5e-03, + atol=5e-03, + equal_nan=True, + check_dtype=True, + ) + torch._dynamo.reset()