@@ -22,23 +22,23 @@

 def export(
     gm: torch.fx.GraphModule,
-    cross_compile_flag: Optional[bool] = False,
+    cross_compile_module: Optional[bool] = False,
 ) -> ExportedProgram:
     """Export the result of TensorRT compilation into the desired output format.

     Arguments:
         gm (torch.fx.GraphModule): Compiled Torch-TensorRT module, generated by ``torch_tensorrt.dynamo.compile``
         inputs (torch.Tensor): Torch input tensors
-        cross_compile_flag (bool): Flag to indicated whether it is cross_compilation enabled or not
+        cross_compile_module (bool): Flag to indicate whether cross-compilation is enabled
     """
-    patched_module = transform(gm, cross_compile_flag)
+    patched_module = transform(gm, cross_compile_module)
     exp_program = create_trt_exp_program(patched_module)
     return exp_program


 def transform(
     gm: torch.fx.GraphModule,
-    cross_compile_flag: Optional[bool] = False,
+    cross_compile_module: Optional[bool] = False,
 ) -> torch.fx.GraphModule:
     """
     Transforms the graphmodule by inlining Pytorch and TensorRT submodules.
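For context, a minimal usage sketch of the renamed flag. It assumes the usual dynamo frontend flow (torch.export.export followed by torch_tensorrt.dynamo.compile on a CUDA machine with TensorRT installed); the toy model and shapes are hypothetical:

import torch
import torch_tensorrt

model = torch.nn.Linear(8, 4).eval().cuda()  # hypothetical stand-in model
inputs = (torch.randn(2, 8).cuda(),)

ep = torch.export.export(model, inputs)
gm = torch_tensorrt.dynamo.compile(ep, inputs=list(inputs))

# Normal workflow: the exported graph carries execute_engine nodes.
exp_program = export(gm)

# Cross-compile-for-Windows workflow: engines are packed behind
# no_op_placeholder_for_execute_engine nodes instead.
exp_program_win = export(gm, cross_compile_module=True)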
@@ -48,7 +48,7 @@ def transform(
     Arguments:
         gm (torch.fx.GraphModule): Compiled Torch-TensorRT module, generated by ``torch_tensorrt.dynamo.compile``
         inputs (torch.Tensor): Torch input tensors
-        cross_compile_flag (bool): Flag to indicated whether it is cross_compilation enabled or not
+        cross_compile_module (bool): Flag to indicate whether cross-compilation is enabled

     Returns an inlined torch.fx.GraphModule
     """
@@ -57,7 +57,7 @@ def transform(
     gm = copy.deepcopy(gm)

     # Inline TensorRT submodules
-    inline_trt_modules(gm, cross_compile_flag)
+    inline_trt_modules(gm, cross_compile_module)

     # Inline pytorch submodules
     inline_torch_modules(gm)
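A quick way to check which branch transform took is to scan the result for the engine call target (transform deep-copies gm, so the input module is untouched). A sketch, assuming a gm produced as in the earlier example:

patched = transform(gm, cross_compile_module=True)
targets = [n.target for n in patched.graph.nodes if n.op == "call_function"]
# Expect torch.ops.tensorrt.no_op_placeholder_for_execute_engine.default here;
# with cross_compile_module=False it would be torch.ops.tensorrt.execute_engine.default.
print(targets)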
@@ -356,7 +356,7 @@ def create_trt_exp_program(


 def inline_trt_modules(
-    gm: torch.fx.GraphModule, cross_compile_flag: Optional[bool] = False
+    gm: torch.fx.GraphModule, cross_compile_module: Optional[bool] = False
 ) -> torch.fx.GraphModule:
     """
     Replace TRT submodules with trt engine nodes.
@@ -380,7 +380,16 @@ def inline_trt_modules(
         num_outputs = len(trt_module_node.meta["val"])
         # Insert a call_function node to perform inference on TRT engine
         with gm.graph.inserting_before(trt_module_node):
-            if not cross_compile_flag:
+            if cross_compile_module:
+                engine_info = trt_module._pack_engine_info()
+                engine_bytes = engine_info[ENGINE_IDX]
+                engine_info[ENGINE_IDX] = base64.b64encode(engine_bytes).decode("utf-8")
+                # insert a no-op placeholder node; it is replaced with the actual execute_engine node when the module is loaded on Windows
+                trt_node = gm.graph.call_function(
+                    torch.ops.tensorrt.no_op_placeholder_for_execute_engine.default,
+                    (trt_module_node.args, *engine_info),
+                )
+            else:
                 # for the normal workflow: use the execute_engine node
                 engine_name = f"{name}_engine"
                 setattr(gm, engine_name, trt_module.engine)
@@ -396,16 +405,6 @@ def inline_trt_modules(
             engine_node.meta["val"] = CustomObjArgument(
                 name=engine_node.name, class_fqn=""
            )
-        else:
-            # for the cross compile for windows workflow: use the no_op_placeholder node
-            engine_info = trt_module._pack_engine_info()
-            engine_bytes = engine_info[ENGINE_IDX]
-            engine_info[ENGINE_IDX] = base64.b64encode(engine_bytes).decode("utf-8")
-            # insert the no_placeholder node in the graph which should be replaced to the actual execute_engine node while load in the windows
-            trt_node = gm.graph.call_function(
-                torch.ops.tensorrt.no_op_placeholder_for_execute_engine.default,
-                (trt_module_node.args, *engine_info),
-            )
         # set trt_node.meta with trt_module_node.meta
         assert num_outputs > 0
         trt_node.meta["val"] = trt_module_node.meta["val"]
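The base64 step exists because the serialized engine is raw bytes, while the placeholder node must embed it as a plain-string graph argument. The round-trip is the standard library's; a self-contained sketch (the byte string is a stand-in for a real serialized engine, and the decode side mirrors what the Windows-side loader has to do):

import base64

engine_bytes = b"\x00\x01serialized TensorRT engine\xff"  # stand-in payload
encoded = base64.b64encode(engine_bytes).decode("utf-8")  # string stored in the graph
decoded = base64.b64decode(encoded)                       # bytes recovered at load time
assert decoded == engine_bytes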
@@ -464,16 +463,10 @@ def replace_execute_engine_no_op_node(
                 name=engine_node.name, class_fqn=""
             )

-        if len(no_op_placeholder_node.meta["val"]) == 1:
-            with gm.graph.inserting_after(trt_node):
-                getitem_output = gm.graph.call_function(operator.getitem, (trt_node, 0))
-                getitem_output.meta["val"] = trt_node.meta["val"]
-                no_op_placeholder_node.replace_all_uses_with(getitem_output)
-        else:
-            no_op_placeholder_node.replace_all_uses_with(trt_node)
-            getitem_nodes = trt_node.users
-            for idx, getitem_node in enumerate(getitem_nodes):
-                getitem_node.meta["val"] = trt_node.meta["val"][idx]
+        no_op_placeholder_node.replace_all_uses_with(trt_node)
+        getitem_nodes = trt_node.users
+        for idx, getitem_node in enumerate(getitem_nodes):
+            getitem_node.meta["val"] = trt_node.meta["val"][idx]

         gm.graph.erase_node(no_op_placeholder_node)
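This simplification appears to rely on every consumer of the replaced node being an operator.getitem node (one per engine output), so replace_all_uses_with(trt_node) plus re-stamping each getitem's meta["val"] handles the single-output case without the old special-cased branch. A minimal FX sketch of that consumer pattern (the toy function is hypothetical):

import operator
import torch
import torch.fx

def f(x):
    out = x.unbind(0)  # multi-output: each result is consumed via operator.getitem
    return out[0] + 1, out[1] + 2

gm = torch.fx.symbolic_trace(f)
unbind_node = next(
    n for n in gm.graph.nodes if n.op == "call_method" and n.target == "unbind"
)
# The same users loop the new code walks when fixing up meta["val"]:
assert all(user.target is operator.getitem for user in unbind_node.users)
for idx, getitem_node in enumerate(unbind_node.users):
    print(idx, getitem_node)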