@@ -678,17 +678,24 @@ def compile(
     )
 
     gm = exported_program.module()
-    # Move the weights in the state_dict to CPU
     logger.debug("Input graph: " + str(gm.graph))
 
     # Apply lowering on the graph module
     gm = post_lowering(gm, settings)
     logger.debug("Lowered Input graph: " + str(gm.graph))
+
+    # Move the weights in the state_dict to CPU
     if offload_module_to_cpu:
         exported_program.module().to(CPU_DEVICE)
         logger.info(
-            "The model is offloaded to CPU during compilation. If you want to keep the model on GPU, set offload_module_to_cpu=False. "
+            "The PyTorch model was moved to the CPU to allocate all GPU memory to TensorRT. To retain the model on the GPU, set offload_module_to_cpu=False"
         )
+    else:
+        remaining_memory, total_memory = torch.cuda.mem_get_info()
+        if remaining_memory < total_memory / 2:
+            logger.warning(
+                "The remaining GPU memory is not enough to compile the model. This may lead to an OOM error. Consider setting offload_module_to_cpu=True."
+            )
     trt_gm = compile_module(
         gm, trt_arg_inputs, trt_kwarg_inputs, settings, engine_cache
     )
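For context on the new offload path, a minimal usage sketch follows. The compile() signature is not shown in this diff, so the entry point, argument names, and the toy model below are illustrative assumptions; only the offload_module_to_cpu flag and the CPU-offload / GPU-headroom behavior come from the hunk above.

# Illustrative sketch only: entry point and argument names are assumed;
# offload_module_to_cpu is the flag handled in the hunk above.
import torch
import torch_tensorrt

model = torch.nn.Linear(1024, 1024).eval().cuda()   # placeholder model
example_input = torch.randn(8, 1024).cuda()

exported = torch.export.export(model, (example_input,))

# With offload_module_to_cpu=True the exported module's weights are moved to
# CPU during compilation, freeing GPU memory for TensorRT engine building;
# with False, the code above warns if less than half of GPU memory is free.
trt_module = torch_tensorrt.dynamo.compile(
    exported,
    arg_inputs=(example_input,),
    offload_module_to_cpu=True,
)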
@@ -833,7 +840,7 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
                 str(name),
                 str(submodule.graph),
             )
-            submodule.to(torch.cuda.current_device())
+            submodule.to(to_torch_device(settings.device))
             continue
 
         if name not in submodule_node_dict:
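The second hunk swaps the hard-coded torch.cuda.current_device() for the device configured in the compilation settings. A small sketch of that resolution step, assuming to_torch_device lives in torch_tensorrt.dynamo.utils and accepts the usual device specs (the import path and accepted inputs are assumptions, not shown in this diff):

# Sketch: resolve the target device from compilation settings instead of
# whichever CUDA device happens to be current. Import path is assumed.
import torch
from torch_tensorrt.dynamo.utils import to_torch_device

device = to_torch_device("cuda:0")    # also accepts torch.device / Device specs (assumed)
submodule = torch.nn.Linear(16, 16)   # placeholder for a non-accelerated submodule
submodule.to(device)                  # mirrors the updated call in the hunk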