16 changes: 12 additions & 4 deletions py/torch_tensorrt/dynamo/_compiler.py
@@ -373,15 +373,21 @@ def cross_compile_for_windows(
     gm = exported_program.module()
     logger.debug("Input graph: " + str(gm.graph))
 
+    # Move the weights in the state_dict to CPU. We should do this before post_lowering for KV cache support.
+    if offload_module_to_cpu:
+        deallocate_module(exported_program.module())
     # Apply lowering on the graph module
     gm = post_lowering(gm, settings)
+    logger.debug(f"CPU memory usage after post_lowering: {get_cpu_memory_usage()} MB")
     logger.debug("Lowered Input graph: " + str(gm.graph))
 
     # Move the weights in the state_dict to CPU
     if offload_module_to_cpu:
-        deallocate_module(exported_program.module(), delete_module=False)
+        deallocate_module(gm)
         logger.info(
             "The PyTorch model was moved to the CPU to allocate all GPU memory to TensorRT. To retain the model on the GPU, set offload_module_to_cpu=False"
         )
+        logger.debug(f"CPU memory usage after CPU offload: {get_cpu_memory_usage()} MB")
     else:
         remaining_memory, total_memory = torch.cuda.mem_get_info()
         if remaining_memory < total_memory // 2:
@@ -766,15 +772,17 @@ def compile(
     # Move the weights in the state_dict to CPU
     logger.debug("Input graph: " + str(gm.graph))
 
+    # Move the weights in the state_dict to CPU. We should do this before post_lowering for KV cache support.
+    if offload_module_to_cpu:
+        deallocate_module(exported_program.module())
     # Apply lowering on the graph module
     gm = post_lowering(gm, settings)
     logger.debug(f"CPU memory usage after post_lowering: {get_cpu_memory_usage()} MB")
     logger.debug("Lowered Input graph: " + str(gm.graph))
 
     # Move the weights in the state_dict to CPU
     if offload_module_to_cpu:
-        deallocate_module(gm, delete_module=False)
-        deallocate_module(exported_program.module(), delete_module=False)
+        deallocate_module(gm)
         logger.info(
             "The PyTorch model was moved to the CPU to allocate all GPU memory to TensorRT. To retain the model on the GPU, set offload_module_to_cpu=False"
         )
@@ -1419,7 +1427,7 @@ def convert_exported_program_to_serialized_trt_engine(
 
     # Move the weights in the state_dict to CPU
     if offload_module_to_cpu:
-        deallocate_module(exported_program.module(), delete_module=False)
+        deallocate_module(exported_program.module())
         logger.info(
             "The PyTorch model was moved to the CPU to allocate all GPU memory to TensorRT. To retain the model on the GPU, set offload_module_to_cpu=False"
         )
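Net effect of the _compiler.py changes: the exported module's weights are offloaded to the CPU before post_lowering runs (the new comment ties this ordering to KV cache support), the lowered graph module is offloaded after lowering, and every call site drops the removed delete_module flag. Below is a minimal sketch of the user-facing path this code serves; the toy model and input shapes are placeholders, not part of this diff:

import torch
import torch_tensorrt

model = torch.nn.Linear(128, 64).eval().cuda()
inputs = [torch.randn(8, 128).cuda()]
exp_program = torch.export.export(model, tuple(inputs))

# offload_module_to_cpu=True moves the PyTorch weights to host memory so
# TensorRT can use the whole GPU while building the engine.
trt_gm = torch_tensorrt.dynamo.compile(
    exp_program,
    inputs=inputs,
    offload_module_to_cpu=True,
)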
8 changes: 4 additions & 4 deletions py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py
@@ -262,7 +262,7 @@ def update_refit_condition(self) -> None:
         args, kwargs, result = self.run_info
         self.original_model.to(to_torch_device(self.trt_device))
         new_result = self.original_model(*args, **kwargs)
-        deallocate_module(self.original_model, delete_module=False)
+        deallocate_module(self.original_model)
         if check_output_equal(result, new_result):
             self.refit_state.set_state(RefitFlag.LIVE)
             return
@@ -311,7 +311,7 @@ def refit_gm(self) -> None:
             in_place=True,
         )
 
-        deallocate_module(self.original_model, delete_module=False)
+        deallocate_module(self.original_model)
 
     def get_exported_program(self) -> torch.export.ExportedProgram:
 
@@ -372,7 +372,7 @@ def compile(self) -> None:
             **self.additional_settings,
         )
         if self.additional_settings.get("offload_module_to_cpu", False):
-            deallocate_module(self.original_model, delete_module=False)
+            deallocate_module(self.original_model)
         if self.enable_weight_streaming:
             self.set_weight_streaming_ctx(self.weight_streaming_budget)
 
@@ -738,7 +738,7 @@ def load(path: str) -> Any:
         module.exp_program = torch.export.export(
             module.original_model, module.arg_inputs, kwargs=module.kwarg_inputs
         )
-        deallocate_module(module.original_model, delete_module=False)
+        deallocate_module(module.original_model)
         cls = module.__class__
         module.__class__ = type(
             module.original_model.__class__.__name__,
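All four touched call sites in the mutable-module runtime (update_refit_condition, refit_gm, compile, load) keep the original model object alive in self.original_model / module.original_model for later refits, which is exactly the behavior the simplified deallocate_module now always provides. A hedged usage sketch of that runtime follows; the toy model is a placeholder, and passing offload_module_to_cpu through the constructor kwargs mirrors the additional_settings lookup visible in the compile hunk above:

import torch
import torch_tensorrt

model = torch.nn.Linear(32, 32).eval().cuda()
# Extra kwargs are forwarded to the dynamo compiler settings;
# offload_module_to_cpu keeps the PyTorch weights on the host once the
# TensorRT engine is built.
mutable = torch_tensorrt.MutableTorchTensorRTModule(model, offload_module_to_cpu=True)
out = mutable(torch.randn(4, 32).cuda())  # first call triggers compilation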
4 changes: 1 addition & 3 deletions py/torch_tensorrt/dynamo/utils.py
@@ -127,14 +127,12 @@ def unified_dtype_converter(
     raise TypeError("%s is not a supported dtype" % dtype)
 
 
-def deallocate_module(module: torch.fx.GraphModule, delete_module: bool = True) -> None:
+def deallocate_module(module: torch.fx.GraphModule) -> None:
     """
     This is a helper function to delete the instance of module. We first move it to CPU and then
     delete the object. This function ensures the GPU memory occupied by the module is released effectively after this call
     """
     module.to(CPU_DEVICE)
-    if delete_module:
-        del module
     torch.cuda.empty_cache()
     gc.collect()
 
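The utils.py hunk is the root of the call-site churn: delete_module is removed because `del module` inside the helper only unbinds the function-local name. Any reference the caller still holds keeps the GraphModule alive, and the GPU memory is actually released by the move to CPU_DEVICE plus torch.cuda.empty_cache(), not by the del (the docstring's mention of deleting the object is now historical). A standalone illustration of that refcounting point, in plain Python with hypothetical names:

import sys

def fake_deallocate(obj, delete_obj=True):
    # 'del obj' drops only this frame's binding; the caller's reference
    # keeps the object alive, so nothing is actually freed here.
    if delete_obj:
        del obj

x = [0] * 1000
before = sys.getrefcount(x)
fake_deallocate(x)
after = sys.getrefcount(x)
assert before == after  # the caller's object survives the helper's del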