Commit cef66f4
Fixed comments
1 parent 76cab94 commit cef66f4

File tree: 1 file changed (+10, -3 lines)

Diff for: py/torch_tensorrt/dynamo/_compiler.py

@@ -678,17 +678,24 @@ def compile(
     )
 
     gm = exported_program.module()
-    # Move the weights in the state_dict to CPU
     logger.debug("Input graph: " + str(gm.graph))
 
     # Apply lowering on the graph module
     gm = post_lowering(gm, settings)
     logger.debug("Lowered Input graph: " + str(gm.graph))
+
+    # Move the weights in the state_dict to CPU
     if offload_module_to_cpu:
         exported_program.module().to(CPU_DEVICE)
         logger.info(
-            "The model is offloaded to CPU during compilation. If you want to keep the model on GPU, set offload_module_to_cpu=False."
+            "The PyTorch model was moved to the CPU to allocate all GPU memory to TensorRT. To retain the model on the GPU, set offload_module_to_cpu=False"
         )
+    else:
+        remaining_memory, total_memory = torch.cuda.mem_get_info()
+        if remaining_memory < total_memory / 2:
+            logger.warning(
+                "The remaining GPU memory is not enough to compile the model. This may lead to an OOM error. Consider setting offload_module_to_cpu=True."
+            )
     trt_gm = compile_module(
         gm, trt_arg_inputs, trt_kwarg_inputs, settings, engine_cache
     )
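Note on the new else branch above: torch.cuda.mem_get_info() returns a (free_bytes, total_bytes) tuple for the current CUDA device, so the added check warns when less than half of device memory is free before engine building starts. Below is a minimal standalone sketch of the same headroom check; the helper name check_gpu_headroom, its boolean return value, and the 0.5 default threshold are illustrative, not part of torch_tensorrt:

import logging

import torch

logger = logging.getLogger(__name__)

def check_gpu_headroom(threshold: float = 0.5) -> bool:
    # Returns True when at least `threshold` of the GPU memory is free.
    # Mirrors the check added in this commit; the name and return value
    # are hypothetical conveniences for this sketch.
    if not torch.cuda.is_available():
        return True
    # mem_get_info() reports (free_bytes, total_bytes) for the current device.
    remaining_memory, total_memory = torch.cuda.mem_get_info()
    if remaining_memory < total_memory * threshold:
        logger.warning(
            "Only %.1f GiB of %.1f GiB GPU memory is free; compilation may "
            "hit an OOM. Consider setting offload_module_to_cpu=True.",
            remaining_memory / 2**30,
            total_memory / 2**30,
        )
        return False
    return True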
@@ -833,7 +840,7 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
                 str(name),
                 str(submodule.graph),
             )
-            submodule.to(torch.cuda.current_device())
+            submodule.to(to_torch_device(settings.device))
             continue
 
         if name not in submodule_node_dict:
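For context: the first hunk moves the CPU offload to after post_lowering() so lowering runs before the weights leave the GPU, and the second hunk sends fallback submodules to the user-configured compilation device (to_torch_device(settings.device)) instead of whatever torch.cuda.current_device() happens to be. A hedged usage sketch of the offload_module_to_cpu flag follows; the toy model and input shapes are placeholders, and availability of the flag depends on the installed torch_tensorrt version:

import torch
import torch_tensorrt

model = (
    torch.nn.Sequential(
        torch.nn.Linear(64, 128),
        torch.nn.ReLU(),
        torch.nn.Linear(128, 8),
    )
    .cuda()
    .eval()
)
inputs = [torch.randn(1, 64, device="cuda")]
exported = torch.export.export(model, tuple(inputs))

# offload_module_to_cpu=True releases the GPU copy of the PyTorch weights
# during TensorRT engine building, at the cost of moving the module to CPU.
trt_gm = torch_tensorrt.dynamo.compile(
    exported,
    inputs=inputs,
    offload_module_to_cpu=True,
)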
