Skip to content

Commit af6eb74

Browse files
committed
slight code reorg and bug correction for cross_compile
1 parent cb3bd17 commit af6eb74

File tree

3 files changed

+23
-29
lines changed

3 files changed

+23
-29
lines changed

py/torch_tensorrt/dynamo/_compiler.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1206,7 +1206,7 @@ def save_cross_compiled_exported_program(
12061206

12071207
from torch_tensorrt.dynamo._exporter import export
12081208

1209-
exp_program = export(gm, cross_compile_flag=True)
1209+
exp_program = export(gm, cross_compile_module=True)
12101210
torch.export.save(exp_program, file_path)
12111211
logger.debug(f"successfully saved the module for windows at {file_path}")
12121212

py/torch_tensorrt/dynamo/_exporter.py

+21-28
Original file line numberDiff line numberDiff line change
@@ -22,23 +22,23 @@
2222

2323
def export(
2424
gm: torch.fx.GraphModule,
25-
cross_compile_flag: Optional[bool] = False,
25+
cross_compile_module: Optional[bool] = False,
2626
) -> ExportedProgram:
2727
"""Export the result of TensorRT compilation into the desired output format.
2828
2929
Arguments:
3030
gm (torch.fx.GraphModule): Compiled Torch-TensorRT module, generated by ``torch_tensorrt.dynamo.compile``
3131
inputs (torch.Tensor): Torch input tensors
32-
cross_compile_flag (bool): Flag to indicated whether it is cross_compilation enabled or not
32+
cross_compile_module (bool): Flag to indicate whether cross-compilation is enabled or not
3333
"""
34-
patched_module = transform(gm, cross_compile_flag)
34+
patched_module = transform(gm, cross_compile_module)
3535
exp_program = create_trt_exp_program(patched_module)
3636
return exp_program
3737

3838

3939
def transform(
4040
gm: torch.fx.GraphModule,
41-
cross_compile_flag: Optional[bool] = False,
41+
cross_compile_module: Optional[bool] = False,
4242
) -> torch.fx.GraphModule:
4343
"""
4444
Transforms the graphmodule by inlining Pytorch and TensorRT submodules.
@@ -48,7 +48,7 @@ def transform(
4848
Arguments:
4949
gm (torch.fx.GraphModule): Compiled Torch-TensorRT module, generated by ``torch_tensorrt.dynamo.compile``
5050
inputs (torch.Tensor): Torch input tensors
51-
cross_compile_flag (bool): Flag to indicated whether it is cross_compilation enabled or not
51+
cross_compile_module (bool): Flag to indicate whether cross-compilation is enabled or not
5252
5353
Returns an inlined torch.fx.GraphModule
5454
"""
@@ -57,7 +57,7 @@ def transform(
5757
gm = copy.deepcopy(gm)
5858

5959
# Inline TensorRT submodules
60-
inline_trt_modules(gm, cross_compile_flag)
60+
inline_trt_modules(gm, cross_compile_module)
6161

6262
# Inline pytorch submodules
6363
inline_torch_modules(gm)
@@ -356,7 +356,7 @@ def create_trt_exp_program(
356356

357357

358358
def inline_trt_modules(
359-
gm: torch.fx.GraphModule, cross_compile_flag: Optional[bool] = False
359+
gm: torch.fx.GraphModule, cross_compile_module: Optional[bool] = False
360360
) -> torch.fx.GraphModule:
361361
"""
362362
Replace TRT submodules with trt engine nodes.
@@ -380,7 +380,16 @@ def inline_trt_modules(
380380
num_outputs = len(trt_module_node.meta["val"])
381381
# Insert a call_function node to perform inference on TRT engine
382382
with gm.graph.inserting_before(trt_module_node):
383-
if not cross_compile_flag:
383+
if cross_compile_module:
384+
engine_info = trt_module._pack_engine_info()
385+
engine_bytes = engine_info[ENGINE_IDX]
386+
engine_info[ENGINE_IDX] = base64.b64encode(engine_bytes).decode("utf-8")
387+
# insert the no-op placeholder node in the graph, which should be replaced with the actual execute_engine node when loaded on Windows
388+
trt_node = gm.graph.call_function(
389+
torch.ops.tensorrt.no_op_placeholder_for_execute_engine.default,
390+
(trt_module_node.args, *engine_info),
391+
)
392+
else:
384393
# for the normal workflow: use the execute_engine node
385394
engine_name = f"{name}_engine"
386395
setattr(gm, engine_name, trt_module.engine)
@@ -396,16 +405,6 @@ def inline_trt_modules(
396405
engine_node.meta["val"] = CustomObjArgument(
397406
name=engine_node.name, class_fqn=""
398407
)
399-
else:
400-
# for the cross compile for windows workflow: use the no_op_placeholder node
401-
engine_info = trt_module._pack_engine_info()
402-
engine_bytes = engine_info[ENGINE_IDX]
403-
engine_info[ENGINE_IDX] = base64.b64encode(engine_bytes).decode("utf-8")
404-
# insert the no_placeholder node in the graph which should be replaced to the actual execute_engine node while load in the windows
405-
trt_node = gm.graph.call_function(
406-
torch.ops.tensorrt.no_op_placeholder_for_execute_engine.default,
407-
(trt_module_node.args, *engine_info),
408-
)
409408
# set trt_node.meta with trt_module_node.meta
410409
assert num_outputs > 0
411410
trt_node.meta["val"] = trt_module_node.meta["val"]
@@ -464,16 +463,10 @@ def replace_execute_engine_no_op_node(
464463
name=engine_node.name, class_fqn=""
465464
)
466465

467-
if len(no_op_placeholder_node.meta["val"]) == 1:
468-
with gm.graph.inserting_after(trt_node):
469-
getitem_output = gm.graph.call_function(operator.getitem, (trt_node, 0))
470-
getitem_output.meta["val"] = trt_node.meta["val"]
471-
no_op_placeholder_node.replace_all_uses_with(getitem_output)
472-
else:
473-
no_op_placeholder_node.replace_all_uses_with(trt_node)
474-
getitem_nodes = trt_node.users
475-
for idx, getitem_node in enumerate(getitem_nodes):
476-
getitem_node.meta["val"] = trt_node.meta["val"][idx]
466+
no_op_placeholder_node.replace_all_uses_with(trt_node)
467+
getitem_nodes = trt_node.users
468+
for idx, getitem_node in enumerate(getitem_nodes):
469+
getitem_node.meta["val"] = trt_node.meta["val"][idx]
477470

478471
gm.graph.erase_node(no_op_placeholder_node)
479472

py/torch_tensorrt/runtime/_utils.py

+1
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,7 @@ def no_op_placeholder_for_execute_engine(
144144
serialized_hardware_compatible: str,
145145
serialized_metadata: str,
146146
serialized_target_platform: str,
147+
serialized_require_output_allocator: str,
147148
) -> List[torch.Tensor]:
148149
raise RuntimeError(
149150
"The saved model is cross compiled for windows in Linux, should only be loaded in Windows via torch_tensorrt.load_cross_compiled_exported_program() api."

0 commit comments

Comments
 (0)