Skip to content

Commit dc36709

Browse files
authored
slight code reorg and bug correction for cross_compile (#3472)
1 parent 325c83e commit dc36709

File tree

5 files changed

+71
-49
lines changed

5 files changed

+71
-49
lines changed

Diff for: py/torch_tensorrt/dynamo/_compiler.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1206,7 +1206,7 @@ def save_cross_compiled_exported_program(
12061206

12071207
from torch_tensorrt.dynamo._exporter import export
12081208

1209-
exp_program = export(gm, cross_compile_flag=True)
1209+
exp_program = export(gm, cross_compile_module=True)
12101210
torch.export.save(exp_program, file_path)
12111211
logger.debug(f"successfully saved the module for windows at {file_path}")
12121212

Diff for: py/torch_tensorrt/dynamo/_exporter.py

+21-28
Original file line numberDiff line numberDiff line change
@@ -22,23 +22,23 @@
2222

2323
def export(
2424
gm: torch.fx.GraphModule,
25-
cross_compile_flag: Optional[bool] = False,
25+
cross_compile_module: Optional[bool] = False,
2626
) -> ExportedProgram:
2727
"""Export the result of TensorRT compilation into the desired output format.
2828
2929
Arguments:
3030
gm (torch.fx.GraphModule): Compiled Torch-TensorRT module, generated by ``torch_tensorrt.dynamo.compile``
3131
inputs (torch.Tensor): Torch input tensors
32-
cross_compile_flag (bool): Flag to indicated whether it is cross_compilation enabled or not
32+
cross_compile_module (bool): Flag to indicate whether cross-compilation is enabled or not
3333
"""
34-
patched_module = transform(gm, cross_compile_flag)
34+
patched_module = transform(gm, cross_compile_module)
3535
exp_program = create_trt_exp_program(patched_module)
3636
return exp_program
3737

3838

3939
def transform(
4040
gm: torch.fx.GraphModule,
41-
cross_compile_flag: Optional[bool] = False,
41+
cross_compile_module: Optional[bool] = False,
4242
) -> torch.fx.GraphModule:
4343
"""
4444
Transforms the graphmodule by inlining Pytorch and TensorRT submodules.
@@ -48,7 +48,7 @@ def transform(
4848
Arguments:
4949
gm (torch.fx.GraphModule): Compiled Torch-TensorRT module, generated by ``torch_tensorrt.dynamo.compile``
5050
inputs (torch.Tensor): Torch input tensors
51-
cross_compile_flag (bool): Flag to indicated whether it is cross_compilation enabled or not
51+
cross_compile_module (bool): Flag to indicate whether cross-compilation is enabled or not
5252
5353
Returns an inlined torch.fx.GraphModule
5454
"""
@@ -57,7 +57,7 @@ def transform(
5757
gm = copy.deepcopy(gm)
5858

5959
# Inline TensorRT submodules
60-
inline_trt_modules(gm, cross_compile_flag)
60+
inline_trt_modules(gm, cross_compile_module)
6161

6262
# Inline pytorch submodules
6363
inline_torch_modules(gm)
@@ -356,7 +356,7 @@ def create_trt_exp_program(
356356

357357

358358
def inline_trt_modules(
359-
gm: torch.fx.GraphModule, cross_compile_flag: Optional[bool] = False
359+
gm: torch.fx.GraphModule, cross_compile_module: Optional[bool] = False
360360
) -> torch.fx.GraphModule:
361361
"""
362362
Replace TRT submodules with trt engine nodes.
@@ -380,7 +380,16 @@ def inline_trt_modules(
380380
num_outputs = len(trt_module_node.meta["val"])
381381
# Insert a call_function node to perform inference on TRT engine
382382
with gm.graph.inserting_before(trt_module_node):
383-
if not cross_compile_flag:
383+
if cross_compile_module:
384+
engine_info = trt_module._pack_engine_info()
385+
engine_bytes = engine_info[ENGINE_IDX]
386+
engine_info[ENGINE_IDX] = base64.b64encode(engine_bytes).decode("utf-8")
387+
# insert the no_op_placeholder node in the graph, which should be replaced with the actual execute_engine node when loaded on Windows
388+
trt_node = gm.graph.call_function(
389+
torch.ops.tensorrt.no_op_placeholder_for_execute_engine.default,
390+
(trt_module_node.args, *engine_info),
391+
)
392+
else:
384393
# for the normal workflow: use the execute_engine node
385394
engine_name = f"{name}_engine"
386395
setattr(gm, engine_name, trt_module.engine)
@@ -396,16 +405,6 @@ def inline_trt_modules(
396405
engine_node.meta["val"] = CustomObjArgument(
397406
name=engine_node.name, class_fqn=""
398407
)
399-
else:
400-
# for the cross compile for windows workflow: use the no_op_placeholder node
401-
engine_info = trt_module._pack_engine_info()
402-
engine_bytes = engine_info[ENGINE_IDX]
403-
engine_info[ENGINE_IDX] = base64.b64encode(engine_bytes).decode("utf-8")
404-
# insert the no_placeholder node in the graph which should be replaced to the actual execute_engine node while load in the windows
405-
trt_node = gm.graph.call_function(
406-
torch.ops.tensorrt.no_op_placeholder_for_execute_engine.default,
407-
(trt_module_node.args, *engine_info),
408-
)
409408
# set trt_node.meta with trt_module_node.meta
410409
assert num_outputs > 0
411410
trt_node.meta["val"] = trt_module_node.meta["val"]
@@ -464,16 +463,10 @@ def replace_execute_engine_no_op_node(
464463
name=engine_node.name, class_fqn=""
465464
)
466465

467-
if len(no_op_placeholder_node.meta["val"]) == 1:
468-
with gm.graph.inserting_after(trt_node):
469-
getitem_output = gm.graph.call_function(operator.getitem, (trt_node, 0))
470-
getitem_output.meta["val"] = trt_node.meta["val"]
471-
no_op_placeholder_node.replace_all_uses_with(getitem_output)
472-
else:
473-
no_op_placeholder_node.replace_all_uses_with(trt_node)
474-
getitem_nodes = trt_node.users
475-
for idx, getitem_node in enumerate(getitem_nodes):
476-
getitem_node.meta["val"] = trt_node.meta["val"][idx]
466+
no_op_placeholder_node.replace_all_uses_with(trt_node)
467+
getitem_nodes = trt_node.users
468+
for idx, getitem_node in enumerate(getitem_nodes):
469+
getitem_node.meta["val"] = trt_node.meta["val"][idx]
477470

478471
gm.graph.erase_node(no_op_placeholder_node)
479472

Diff for: py/torch_tensorrt/dynamo/runtime/meta_ops/register_meta_ops.py

+21
Original file line numberDiff line numberDiff line change
@@ -150,3 +150,24 @@ def __setstate__(self, serialized_state: List[str]) -> Any:
150150

151151
def __getstate__(self) -> Any:
152152
pass
153+
154+
155+
@torch.library.custom_op(
156+
"tensorrt::no_op_placeholder_for_execute_engine", mutates_args=()
157+
)
158+
def no_op_placeholder_for_execute_engine(
159+
inputs: List[torch.Tensor],
160+
abi_version: str,
161+
name: str,
162+
serialized_device_info: str,
163+
serialized_engine: str,
164+
serialized_in_binding_names: str,
165+
serialized_out_binding_names: str,
166+
serialized_hardware_compatible: str,
167+
serialized_metadata: str,
168+
serialized_target_platform: str,
169+
serialized_require_output_allocator: str,
170+
) -> List[torch.Tensor]:
171+
raise RuntimeError(
172+
"The saved model is cross compiled for windows in Linux, should only be loaded in Windows via torch_tensorrt.load_cross_compiled_exported_program() api."
173+
)

Diff for: py/torch_tensorrt/runtime/_utils.py

-20
Original file line numberDiff line numberDiff line change
@@ -128,23 +128,3 @@ def _get_most_compatible_device(
128128
best_match = candidate
129129

130130
return best_match
131-
132-
133-
@torch.library.custom_op(
134-
"tensorrt::no_op_placeholder_for_execute_engine", mutates_args=()
135-
)
136-
def no_op_placeholder_for_execute_engine(
137-
inputs: List[torch.Tensor],
138-
abi_version: str,
139-
name: str,
140-
serialized_device_info: str,
141-
serialized_engine: str,
142-
serialized_in_binding_names: str,
143-
serialized_out_binding_names: str,
144-
serialized_hardware_compatible: str,
145-
serialized_metadata: str,
146-
serialized_target_platform: str,
147-
) -> List[torch.Tensor]:
148-
raise RuntimeError(
149-
"The saved model is cross compiled for windows in Linux, should only be loadded in Windows via torch_tensorrt.load_cross_compiled_exported_program() api."
150-
)

Diff for: tests/py/dynamo/runtime/test_003_cross_compile_for_windows.py

+28
Original file line numberDiff line numberDiff line change
@@ -63,3 +63,31 @@ def forward(self, a, b):
6363
)
6464
except Exception as e:
6565
pytest.fail(f"unexpected exception raised: {e}")
66+
67+
@unittest.skipIf(
68+
platform.system() != "Linux" or platform.architecture()[0] != "64bit",
69+
"Cross compile for windows can only be enabled on linux x86-64 platform",
70+
)
71+
@pytest.mark.unit
72+
def test_dynamo_cross_compile_for_windows_multiple_output(self):
73+
class Add(torch.nn.Module):
74+
def forward(self, a, b):
75+
return torch.add(a, b), torch.add(a, b)
76+
77+
model = Add().eval().cuda()
78+
inputs = (torch.randn(2, 3).cuda(), torch.randn(2, 3).cuda())
79+
trt_ep_path = os.path.join(tempfile.gettempdir(), "trt.ep")
80+
exp_program = torch.export.export(model, inputs)
81+
compile_spec = {
82+
"inputs": inputs,
83+
"min_block_size": 1,
84+
}
85+
try:
86+
trt_gm = torch_tensorrt.dynamo.cross_compile_for_windows(
87+
exp_program, **compile_spec
88+
)
89+
torch_tensorrt.dynamo.save_cross_compiled_exported_program(
90+
trt_gm, file_path=trt_ep_path
91+
)
92+
except Exception as e:
93+
pytest.fail(f"unexpected exception raised: {e}")

0 commit comments

Comments
 (0)