From a7b562365252a543fbca19bda93e96cf612bcbf8 Mon Sep 17 00:00:00 2001 From: cehongwang Date: Fri, 30 Jan 2026 22:59:44 +0000 Subject: [PATCH] Fixed the bug caused by cpu offloading --- py/torch_tensorrt/dynamo/_compiler.py | 16 ++++++++++++---- .../runtime/_MutableTorchTensorRTModule.py | 8 ++++---- py/torch_tensorrt/dynamo/utils.py | 4 +--- 3 files changed, 17 insertions(+), 11 deletions(-) diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 24c009c189..8280520e38 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -373,15 +373,21 @@ def cross_compile_for_windows( gm = exported_program.module() logger.debug("Input graph: " + str(gm.graph)) + # Move the weights in the state_dict to CPU. We should do this before post_lowering for KV cache support. + if offload_module_to_cpu: + deallocate_module(exported_program.module()) # Apply lowering on the graph module gm = post_lowering(gm, settings) + logger.debug(f"CPU memory usage after post_lowering: {get_cpu_memory_usage()} MB") logger.debug("Lowered Input graph: " + str(gm.graph)) + # Move the weights in the state_dict to CPU if offload_module_to_cpu: - deallocate_module(exported_program.module(), delete_module=False) + deallocate_module(gm) logger.info( "The PyTorch model was moved to the CPU to allocate all GPU memory to TensorRT. To retain the model on the GPU, set offload_module_to_cpu=False" ) + logger.debug(f"CPU memory usage after CPU offload: {get_cpu_memory_usage()} MB") else: remaining_memory, total_memory = torch.cuda.mem_get_info() if remaining_memory < total_memory // 2: @@ -766,6 +772,9 @@ def compile( # Move the weights in the state_dict to CPU logger.debug("Input graph: " + str(gm.graph)) + # Move the weights in the state_dict to CPU. We should do this before post_lowering for KV cache support. + if offload_module_to_cpu: + deallocate_module(exported_program.module()) # Apply lowering on the graph module gm = post_lowering(gm, settings) logger.debug(f"CPU memory usage after post_lowering: {get_cpu_memory_usage()} MB") @@ -773,8 +782,7 @@ def compile( # Move the weights in the state_dict to CPU if offload_module_to_cpu: - deallocate_module(gm, delete_module=False) - deallocate_module(exported_program.module(), delete_module=False) + deallocate_module(gm) logger.info( "The PyTorch model was moved to the CPU to allocate all GPU memory to TensorRT. To retain the model on the GPU, set offload_module_to_cpu=False" ) @@ -1419,7 +1427,7 @@ def convert_exported_program_to_serialized_trt_engine( # Move the weights in the state_dict to CPU if offload_module_to_cpu: - deallocate_module(exported_program.module(), delete_module=False) + deallocate_module(exported_program.module()) logger.info( "The PyTorch model was moved to the CPU to allocate all GPU memory to TensorRT. To retain the model on the GPU, set offload_module_to_cpu=False" ) diff --git a/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py index be01a37cd1..d3ef7e0a41 100644 --- a/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py @@ -262,7 +262,7 @@ def update_refit_condition(self) -> None: args, kwargs, result = self.run_info self.original_model.to(to_torch_device(self.trt_device)) new_result = self.original_model(*args, **kwargs) - deallocate_module(self.original_model, delete_module=False) + deallocate_module(self.original_model) if check_output_equal(result, new_result): self.refit_state.set_state(RefitFlag.LIVE) return @@ -311,7 +311,7 @@ def refit_gm(self) -> None: in_place=True, ) - deallocate_module(self.original_model, delete_module=False) + deallocate_module(self.original_model) def get_exported_program(self) -> torch.export.ExportedProgram: @@ -372,7 +372,7 @@ def compile(self) -> None: **self.additional_settings, ) if self.additional_settings.get("offload_module_to_cpu", False): - deallocate_module(self.original_model, delete_module=False) + deallocate_module(self.original_model) if self.enable_weight_streaming: self.set_weight_streaming_ctx(self.weight_streaming_budget) @@ -738,7 +738,7 @@ def load(path: str) -> Any: module.exp_program = torch.export.export( module.original_model, module.arg_inputs, kwargs=module.kwarg_inputs ) - deallocate_module(module.original_model, delete_module=False) + deallocate_module(module.original_model) cls = module.__class__ module.__class__ = type( module.original_model.__class__.__name__, diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py index abc697a086..54abbd6d44 100644 --- a/py/torch_tensorrt/dynamo/utils.py +++ b/py/torch_tensorrt/dynamo/utils.py @@ -127,14 +127,12 @@ def unified_dtype_converter( raise TypeError("%s is not a supported dtype" % dtype) -def deallocate_module(module: torch.fx.GraphModule, delete_module: bool = True) -> None: +def deallocate_module(module: torch.fx.GraphModule) -> None: """ This is a helper function to delete the instance of module. We first move it to CPU and then delete the object. This function ensures the GPU memory occupied by the module is released effectively after this call """ module.to(CPU_DEVICE) - if delete_module: - del module torch.cuda.empty_cache() gc.collect()