8 changes: 8 additions & 0 deletions core/runtime/TRTEngine.cpp
@@ -281,6 +281,14 @@ void TRTEngine::enable_profiling() {
exec_ctx->setProfiler(trt_engine_profiler.get());
}

void TRTEngine::set_output_tensors_as_unowned(bool enable) {
this->output_tensors_are_unowned = enable;
}

bool TRTEngine::are_output_tensors_unowned() {
return this->output_tensors_are_unowned;
}

void TRTEngine::set_profile_format(std::string format) {
if (format == "trex") {
this->trt_engine_profiler->set_profile_format(TraceFormat::kTREX);
3 changes: 3 additions & 0 deletions core/runtime/TRTEngine.h
@@ -103,6 +103,7 @@ struct TRTEngine : torch::CustomClassHolder {
std::shared_ptr<nvinfer1::ICudaEngine> cuda_engine;
std::shared_ptr<nvinfer1::IExecutionContext> exec_ctx;
std::pair<uint64_t, uint64_t> num_io;
bool output_tensors_are_unowned = false;
std::string name;
RTDevice device_info;

@@ -159,6 +160,8 @@
int64_t get_automatic_device_memory_budget();
std::vector<at::Tensor> infer_outputs(std::vector<std::vector<int64_t>> input_shapes);
void set_pre_allocated_outputs(bool enable);
void set_output_tensors_as_unowned(bool enable);
bool are_output_tensors_unowned();
TorchTRTRuntimeStates runtime_states;
friend std::ostream& operator<<(std::ostream& os, const TRTEngine& engine);
static const char BINDING_DELIM = '%';
7 changes: 5 additions & 2 deletions core/runtime/execute_engine.cpp
@@ -320,8 +320,11 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
}
} // End engine execution (resets to caller stream)

// Create output buffer for next execution of graph or trt context.
if (compiled_engine->use_pre_allocated_outputs) {
// When the pre-allocated output mode is turned on, for intermediate modules, we only create the output in the first
// execution or when shape is changed.
if (compiled_engine->use_pre_allocated_outputs &&
(compiled_engine->pre_allocated_outputs.size() == 0 || compiled_engine->output_tensors_are_unowned ||
shape_changed)) {
compiled_engine->pre_allocated_outputs = create_output_tensors(compiled_engine);
}

2 changes: 2 additions & 0 deletions core/runtime/register_jit_hooks.cpp
@@ -90,6 +90,8 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion =
.def("get_engine_layer_info", &TRTEngine::get_engine_layer_info)
.def("infer_outputs", &TRTEngine::infer_outputs)
.def("reset_captured_graph", &TRTEngine::reset_captured_graph)
.def("set_output_tensors_as_unowned", &TRTEngine::set_output_tensors_as_unowned)
.def("are_output_tensors_unowned", &TRTEngine::are_output_tensors_unowned)
.def_readwrite("use_pre_allocated_outputs", &TRTEngine::use_pre_allocated_outputs)
.def_readwrite("use_output_allocator_outputs", &TRTEngine::use_output_allocator_outputs)
.def_property(
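(Note: the bindings above are what the Python wrappers later in this diff call into. As a rough usage sketch, assuming trt_module is an already set-up TorchTensorRTModule or PythonTorchTensorRTModule, toggling and reading back the flag might look like the following; the helper name mark_final_module is illustrative only.)

def mark_final_module(trt_module) -> None:
    # Only the final engine in a graph should hand ownership of its output
    # tensors to the caller; intermediate engines keep reusing their
    # pre-allocated outputs.
    trt_module.set_output_tensors_as_unowned(True)
    assert trt_module.are_output_tensors_unowned()
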
1 change: 0 additions & 1 deletion examples/dynamo/pre_allocated_output_example.py
@@ -79,7 +79,6 @@ def test_module_perf(model, *input):
optimized_model = torch_tensorrt.compile(
model,
ir="dynamo",
enabled_precisions={torch.half},
inputs=inputs,
)

15 changes: 13 additions & 2 deletions py/torch_tensorrt/dynamo/_compiler.py
@@ -557,7 +557,7 @@ def compile(
stacklevel=2,
)

if kwargs.get("use_explicit_typing", False) == False:
if not kwargs.get("use_explicit_typing", False):
warnings.warn(
"`use_explicit_typing` is deprecated. This setting will be removed and you should enable autocast instead.",
DeprecationWarning,
@@ -1070,14 +1070,25 @@ def preserve_module_specs(
) as f:
f.write(trt_module.get_layer_info())

# Only mark output tensors as unowned for the last TRT Module, since the user has access to its output tensors

# Parse the graph I/O and store it in dryrun tracker
parse_graph_io(gm, dryrun_tracker)

# Replace all FX Modules with TRT Modules
for name, trt_module in trt_modules.items():
setattr(partitioned_module, name, trt_module)
if settings.lazy_engine_init and not settings.enable_cross_compile_for_windows:
getattr(partitioned_module, name).setup_engine()
trt_module = getattr(partitioned_module, name)
trt_module.setup_engine()

output_node = list(partitioned_module.graph.nodes)[-1]
Collaborator

Is this ordering guaranteed? Wouldn't it be more correct to backtrace from the graph outputs?

for arg in output_node.args:
for output in arg:
target = output.target
if "_run_on_acc" not in str(target):
continue
getattr(partitioned_module, target).set_output_tensors_as_unowned(True)

# Reset settings object to user specification after fallback to global partitioning mode
if fast_partitioner_failed:
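(For comparison with the review question above, a hedged sketch of the back-tracing approach, locating the FX graph's output node by its op rather than by its position in the node list, might look like the following; the _run_on_acc naming convention is taken from this diff, everything else is illustrative.)

import torch.fx


def mark_last_acc_modules(partitioned_module: torch.fx.GraphModule) -> None:
    # Find the graph's output node explicitly instead of assuming it is the
    # last entry in graph.nodes.
    output_node = next(n for n in partitioned_module.graph.nodes if n.op == "output")
    # Flag every accelerated submodule whose result feeds the graph output,
    # since those tensors are handed back to the user.
    for producer in output_node.all_input_nodes:
        if producer.op == "call_module" and "_run_on_acc" in str(producer.target):
            getattr(partitioned_module, producer.target).set_output_tensors_as_unowned(True)
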
50 changes: 40 additions & 10 deletions py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
Collaborator

It seems like you are just renaming variables here? Why do we need to make that change?

Collaborator

the pre_allocated_outputs -> allocated_outputs change

Collaborator Author

I also changed self.output_tensors = self.create_output_tensors() to self.allocated_outputs = self.create_output_tensors(), basically merging the two variables together. Isn't that what you wanted me to do?

@@ -218,10 +218,27 @@ def __init__(
self.requires_output_allocator = requires_output_allocator
self.output_allocator: Optional[DynamicOutputAllocator] = None
self.use_output_allocator_outputs = False

self.device = torch.cuda.current_device()
self.cudagraphs_enabled = torch_tensorrt.runtime.get_cudagraphs_mode()
# If the output tensor is not owned by the engine (output_tensors_are_unowned=True), we need to create a new output tensor in each forward pass
self.output_tensors_are_unowned = False
if self.serialized_engine is not None and not self.settings.lazy_engine_init:
self.setup_engine()

def set_output_tensors_as_unowned(self, enabled: bool) -> None:
"""
Flag to set if the output tensors of this engine are solely owned by the Torch-TensorRT Runtime or if they might be shared with a user.
If the tensors are not owned by the runtime, then they must be recreated on every forward call, which may have performance implications.
Typically only the final engine in a graph requires its output tensors to be unowned, and there are performance gains when intermediate engines manage their own standing memory.
Therefore this should only be set to True for the final module in a graph and left False for intermediate modules.

Args:
enabled: bool
Whether the engine's output tensors should be treated as unowned (potentially shared with the user).

"""
self.output_tensors_are_unowned = enabled

def get_streamable_device_memory_budget(self) -> Any:
return self.engine.streamable_weights_size

Expand Down Expand Up @@ -298,6 +315,11 @@ def setup_engine(self) -> None:
if torch_tensorrt.runtime.get_cudagraphs_mode():
self.cudagraph = torch.cuda.CUDAGraph()

self.is_shape_inference_io = {
input_name: self.engine.is_shape_inference_io(input_name)
for input_name in self.input_names
}

def _check_initialized(self) -> None:
if not self.initialized:
raise RuntimeError("PythonTorchTensorRTModule is not initialized.")
@@ -383,7 +405,7 @@ def setup_input_tensors(

# For shape tensors, we use CPU pointers and for data tensors, we use GPU pointers
# as per TensorRT requirements
if self.engine.is_shape_inference_io(input_name):
if self.is_shape_inference_io[input_name]:
# Shape tensor inputs are cast to int64 explicitly
# Currently Torch CPU pointers are not working; numpy pointers are used instead
# to refer to underlying memory
@@ -411,7 +433,7 @@ def create_output_tensors(self) -> List[torch.Tensor]:
output = torch.empty(
size=self.output_shapes[o],
dtype=self.output_dtypes[o],
device=torch.cuda.current_device(),
device=self.device,
)
outputs.append(output)
return outputs
@@ -548,7 +570,12 @@ def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]:

self._caller_stream.wait_stream(self._engine_stream)

if self.use_pre_allocated_outputs:
# When the pre-allocated output mode is turned on, for intermediate modules, we only create the output in the first execution or when shape is changed.
if self.use_pre_allocated_outputs and (
self.output_tensors_are_unowned
or not self.pre_allocated_outputs
or shape_changed
):
self.pre_allocated_outputs = self.create_output_tensors()

if self.cudagraphs_enabled:
@@ -751,13 +778,13 @@ def validate_input_shapes(self, inputs: Sequence[torch.Tensor]) -> bool:
# Representation of input shapes to a given model
# Shapes are concatenated as so:
# x: (3, 4), y: (4, 5) --> Key: (3,4)(4,5)
tensor_inputs = []
for t in inputs:
if not isinstance(t, torch.Tensor):
return True
tensor_inputs.append(t)
if not all(isinstance(t, torch.Tensor) for t in inputs):
return True

new_shape_key = "".join(
str(tuple(t.shape)).replace(" ", "") for t in tensor_inputs
str(tuple(t.shape)).replace(" ", "")
for t in inputs
if isinstance(t, torch.Tensor)
)

# If the new shape key differs from the existing one,
Expand All @@ -768,3 +795,6 @@ def validate_input_shapes(self, inputs: Sequence[torch.Tensor]) -> bool:
return True

return False

def are_output_tensors_unowned(self) -> bool:
return self.output_tensors_are_unowned
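
(To see the end-to-end effect of the flag, a small inspection sketch, assuming optimized_model is the GraphModule returned by torch_tensorrt.compile(..., ir="dynamo"), could be:)

for name, submodule in optimized_model.named_children():
    # Only the final _run_on_acc submodule is expected to report unowned
    # outputs; intermediate engines keep ownership of their output buffers.
    if "_run_on_acc" in name and hasattr(submodule, "are_output_tensors_unowned"):
        print(name, submodule.are_output_tensors_unowned())
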
13 changes: 13 additions & 0 deletions py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py
@@ -156,6 +156,11 @@ def _pack_engine_info(self) -> List[str | bytes]:
metadata = {
"settings": self.settings,
"weight_name_map": self.weight_name_map,
"output_tensors_are_unowned": (
False
if self.engine is None
else self.engine.are_output_tensors_unowned()
),
}
target_platform = (
Platform.current_platform()
@@ -284,6 +289,8 @@ def set_extra_state(self, state: SerializedTorchTensorRTModuleFmt) -> None:
metadata = TorchTensorRTModule.decode_metadata(serialized_metadata)
self.settings = metadata["settings"]
self.weight_name_map = metadata["weight_name_map"]
self.output_tensors_are_unowned = metadata["output_tensors_are_unowned"]
self.engine.set_output_tensors_as_unowned(self.output_tensors_are_unowned)

else:
self.engine = None
@@ -355,6 +362,12 @@ def enable_profiling(
self.engine.enable_profiling()
self.engine.set_profile_format(profile_format)

def set_output_tensors_as_unowned(self, enabled: bool) -> None:
self.engine.set_output_tensors_as_unowned(enabled)

def are_output_tensors_unowned(self) -> bool:
return self.engine.are_output_tensors_unowned() # type: ignore[no-any-return]

def disable_profiling(self) -> None:
"""Disable the profiler"""
if self.engine is None: