From 108b0ec93023ebecff1d5d0985cd95f7d1fd2997 Mon Sep 17 00:00:00 2001 From: apbose Date: Mon, 14 Apr 2025 13:31:32 -0700 Subject: [PATCH 1/2] slight code reorg and bug correction for cross_compile --- py/torch_tensorrt/dynamo/_compiler.py | 2 +- py/torch_tensorrt/dynamo/_exporter.py | 49 ++++++++++++--------------- py/torch_tensorrt/runtime/_utils.py | 1 + 3 files changed, 23 insertions(+), 29 deletions(-) diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 6928347baa..acd16a32f0 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -1206,7 +1206,7 @@ def save_cross_compiled_exported_program( from torch_tensorrt.dynamo._exporter import export - exp_program = export(gm, cross_compile_flag=True) + exp_program = export(gm, cross_compile_module=True) torch.export.save(exp_program, file_path) logger.debug(f"successfully saved the module for windows at {file_path}") diff --git a/py/torch_tensorrt/dynamo/_exporter.py b/py/torch_tensorrt/dynamo/_exporter.py index f2d4cfee88..17e0ad4561 100644 --- a/py/torch_tensorrt/dynamo/_exporter.py +++ b/py/torch_tensorrt/dynamo/_exporter.py @@ -22,23 +22,23 @@ def export( gm: torch.fx.GraphModule, - cross_compile_flag: Optional[bool] = False, + cross_compile_module: Optional[bool] = False, ) -> ExportedProgram: """Export the result of TensorRT compilation into the desired output format. Arguments: gm (torch.fx.GraphModule): Compiled Torch-TensorRT module, generated by ``torch_tensorrt.dynamo.compile`` inputs (torch.Tensor): Torch input tensors - cross_compile_flag (bool): Flag to indicated whether it is cross_compilation enabled or not + cross_compile_module (bool): Flag to indicated whether it is cross_compilation enabled or not """ - patched_module = transform(gm, cross_compile_flag) + patched_module = transform(gm, cross_compile_module) exp_program = create_trt_exp_program(patched_module) return exp_program def transform( gm: torch.fx.GraphModule, - cross_compile_flag: Optional[bool] = False, + cross_compile_module: Optional[bool] = False, ) -> torch.fx.GraphModule: """ Transforms the graphmodule by inlining Pytorch and TensorRT submodules. @@ -48,7 +48,7 @@ def transform( Arguments: gm (torch.fx.GraphModule): Compiled Torch-TensorRT module, generated by ``torch_tensorrt.dynamo.compile`` inputs (torch.Tensor): Torch input tensors - cross_compile_flag (bool): Flag to indicated whether it is cross_compilation enabled or not + cross_compile_module (bool): Flag to indicated whether it is cross_compilation enabled or not Returns an inlined torch.fx.GraphModule """ @@ -57,7 +57,7 @@ def transform( gm = copy.deepcopy(gm) # Inline TensorRT submodules - inline_trt_modules(gm, cross_compile_flag) + inline_trt_modules(gm, cross_compile_module) # Inline pytorch submodules inline_torch_modules(gm) @@ -356,7 +356,7 @@ def create_trt_exp_program( def inline_trt_modules( - gm: torch.fx.GraphModule, cross_compile_flag: Optional[bool] = False + gm: torch.fx.GraphModule, cross_compile_module: Optional[bool] = False ) -> torch.fx.GraphModule: """ Replace TRT submodules with trt engine nodes. @@ -380,7 +380,16 @@ def inline_trt_modules( num_outputs = len(trt_module_node.meta["val"]) # Insert a call_function node to perform inference on TRT engine with gm.graph.inserting_before(trt_module_node): - if not cross_compile_flag: + if cross_compile_module: + engine_info = trt_module._pack_engine_info() + engine_bytes = engine_info[ENGINE_IDX] + engine_info[ENGINE_IDX] = base64.b64encode(engine_bytes).decode("utf-8") + # insert the no_placeholder node in the graph which should be replaced to the actual execute_engine node while load in the windows + trt_node = gm.graph.call_function( + torch.ops.tensorrt.no_op_placeholder_for_execute_engine.default, + (trt_module_node.args, *engine_info), + ) + else: # for the normal workflow: use the execute_engine node engine_name = f"{name}_engine" setattr(gm, engine_name, trt_module.engine) @@ -396,16 +405,6 @@ def inline_trt_modules( engine_node.meta["val"] = CustomObjArgument( name=engine_node.name, class_fqn="" ) - else: - # for the cross compile for windows workflow: use the no_op_placeholder node - engine_info = trt_module._pack_engine_info() - engine_bytes = engine_info[ENGINE_IDX] - engine_info[ENGINE_IDX] = base64.b64encode(engine_bytes).decode("utf-8") - # insert the no_placeholder node in the graph which should be replaced to the actual execute_engine node while load in the windows - trt_node = gm.graph.call_function( - torch.ops.tensorrt.no_op_placeholder_for_execute_engine.default, - (trt_module_node.args, *engine_info), - ) # set trt_node.meta with trt_module_node.meta assert num_outputs > 0 trt_node.meta["val"] = trt_module_node.meta["val"] @@ -464,16 +463,10 @@ def replace_execute_engine_no_op_node( name=engine_node.name, class_fqn="" ) - if len(no_op_placeholder_node.meta["val"]) == 1: - with gm.graph.inserting_after(trt_node): - getitem_output = gm.graph.call_function(operator.getitem, (trt_node, 0)) - getitem_output.meta["val"] = trt_node.meta["val"] - no_op_placeholder_node.replace_all_uses_with(getitem_output) - else: - no_op_placeholder_node.replace_all_uses_with(trt_node) - getitem_nodes = trt_node.users - for idx, getitem_node in enumerate(getitem_nodes): - getitem_node.meta["val"] = trt_node.meta["val"][idx] + no_op_placeholder_node.replace_all_uses_with(trt_node) + getitem_nodes = trt_node.users + for idx, getitem_node in enumerate(getitem_nodes): + getitem_node.meta["val"] = trt_node.meta["val"][idx] gm.graph.erase_node(no_op_placeholder_node) diff --git a/py/torch_tensorrt/runtime/_utils.py b/py/torch_tensorrt/runtime/_utils.py index c42a2b2a2b..3e3062c846 100644 --- a/py/torch_tensorrt/runtime/_utils.py +++ b/py/torch_tensorrt/runtime/_utils.py @@ -144,6 +144,7 @@ def no_op_placeholder_for_execute_engine( serialized_hardware_compatible: str, serialized_metadata: str, serialized_target_platform: str, + serialized_require_output_allocator: str, ) -> List[torch.Tensor]: raise RuntimeError( "The saved model is cross compiled for windows in Linux, should only be loadded in Windows via torch_tensorrt.load_cross_compiled_exported_program() api." From f8f0f55903f9f3d4b0d46729eb0c642586126713 Mon Sep 17 00:00:00 2001 From: apbose Date: Fri, 18 Apr 2025 14:48:54 -0700 Subject: [PATCH 2/2] adding test case for multiple outputs and moving op to register_meta_ops --- .../runtime/meta_ops/register_meta_ops.py | 21 ++++++++++++++ py/torch_tensorrt/runtime/_utils.py | 21 -------------- .../test_003_cross_compile_for_windows.py | 28 +++++++++++++++++++ 3 files changed, 49 insertions(+), 21 deletions(-) diff --git a/py/torch_tensorrt/dynamo/runtime/meta_ops/register_meta_ops.py b/py/torch_tensorrt/dynamo/runtime/meta_ops/register_meta_ops.py index 1b6963fa50..500a665688 100644 --- a/py/torch_tensorrt/dynamo/runtime/meta_ops/register_meta_ops.py +++ b/py/torch_tensorrt/dynamo/runtime/meta_ops/register_meta_ops.py @@ -150,3 +150,24 @@ def __setstate__(self, serialized_state: List[str]) -> Any: def __getstate__(self) -> Any: pass + + +@torch.library.custom_op( + "tensorrt::no_op_placeholder_for_execute_engine", mutates_args=() +) +def no_op_placeholder_for_execute_engine( + inputs: List[torch.Tensor], + abi_version: str, + name: str, + serialized_device_info: str, + serialized_engine: str, + serialized_in_binding_names: str, + serialized_out_binding_names: str, + serialized_hardware_compatible: str, + serialized_metadata: str, + serialized_target_platform: str, + serialized_require_output_allocator: str, +) -> List[torch.Tensor]: + raise RuntimeError( + "The saved model is cross compiled for windows in Linux, should only be loadded in Windows via torch_tensorrt.load_cross_compiled_exported_program() api." + ) diff --git a/py/torch_tensorrt/runtime/_utils.py b/py/torch_tensorrt/runtime/_utils.py index 3e3062c846..bc2e5a6a70 100644 --- a/py/torch_tensorrt/runtime/_utils.py +++ b/py/torch_tensorrt/runtime/_utils.py @@ -128,24 +128,3 @@ def _get_most_compatible_device( best_match = candidate return best_match - - -@torch.library.custom_op( - "tensorrt::no_op_placeholder_for_execute_engine", mutates_args=() -) -def no_op_placeholder_for_execute_engine( - inputs: List[torch.Tensor], - abi_version: str, - name: str, - serialized_device_info: str, - serialized_engine: str, - serialized_in_binding_names: str, - serialized_out_binding_names: str, - serialized_hardware_compatible: str, - serialized_metadata: str, - serialized_target_platform: str, - serialized_require_output_allocator: str, -) -> List[torch.Tensor]: - raise RuntimeError( - "The saved model is cross compiled for windows in Linux, should only be loadded in Windows via torch_tensorrt.load_cross_compiled_exported_program() api." - ) diff --git a/tests/py/dynamo/runtime/test_003_cross_compile_for_windows.py b/tests/py/dynamo/runtime/test_003_cross_compile_for_windows.py index 867bf14bee..44a14a74de 100644 --- a/tests/py/dynamo/runtime/test_003_cross_compile_for_windows.py +++ b/tests/py/dynamo/runtime/test_003_cross_compile_for_windows.py @@ -63,3 +63,31 @@ def forward(self, a, b): ) except Exception as e: pytest.fail(f"unexpected exception raised: {e}") + + @unittest.skipIf( + platform.system() != "Linux" or platform.architecture()[0] != "64bit", + "Cross compile for windows can only be enabled on linux x86-64 platform", + ) + @pytest.mark.unit + def test_dynamo_cross_compile_for_windows_multiple_output(self): + class Add(torch.nn.Module): + def forward(self, a, b): + return torch.add(a, b), torch.add(a, b) + + model = Add().eval().cuda() + inputs = (torch.randn(2, 3).cuda(), torch.randn(2, 3).cuda()) + trt_ep_path = os.path.join(tempfile.gettempdir(), "trt.ep") + exp_program = torch.export.export(model, inputs) + compile_spec = { + "inputs": inputs, + "min_block_size": 1, + } + try: + trt_gm = torch_tensorrt.dynamo.cross_compile_for_windows( + exp_program, **compile_spec + ) + torch_tensorrt.dynamo.save_cross_compiled_exported_program( + trt_gm, file_path=trt_ep_path + ) + except Exception as e: + pytest.fail(f"unexpected exception raised: {e}")