Skip to content

Commit 8b92866

Browse files
committed
Added support for modules with graph breaks
1 parent 44b0709 commit 8b92866

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

py/torch_tensorrt/dynamo/_compiler.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
pre_export_lowering,
3838
)
3939
from torch_tensorrt.dynamo.utils import (
40+
CPU_DEVICE,
4041
get_flat_args_with_check,
4142
get_output_metadata,
4243
parse_graph_io,
@@ -676,7 +677,6 @@ def compile(
676677

677678
gm = exported_program.module()
678679
# Move the weights in the state_dict to CPU
679-
exported_program.module().to("cpu")
680680
logger.info(
681681
"The model is moved to CPU during compilation. If you want to keep the model on GPU, call module.to('cuda') on the model after compilation."
682682
)
@@ -686,6 +686,7 @@ def compile(
686686
gm = post_lowering(gm, settings)
687687
logger.debug("Lowered Input graph: " + str(gm.graph))
688688

689+
exported_program.module().to(CPU_DEVICE)
689690
trt_gm = compile_module(
690691
gm, trt_arg_inputs, trt_kwarg_inputs, settings, engine_cache
691692
)
@@ -830,6 +831,7 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
830831
str(name),
831832
str(submodule.graph),
832833
)
834+
submodule.to(torch.cuda.current_device())
833835
continue
834836

835837
if name not in submodule_node_dict:

0 commit comments

Comments
 (0)