|
10 | 10 | from torch._dynamo.backends.common import aot_autograd
|
11 | 11 | from torch._dynamo.utils import detect_fake_mode
|
12 | 12 | from torch._functorch.aot_autograd import aot_export_joint_simple
|
| 13 | +from torch._ops import OpOverload |
13 | 14 | from torch_tensorrt.dynamo import CompilationSettings
|
14 | 15 | from torch_tensorrt.dynamo._compiler import compile_module
|
15 | 16 | from torch_tensorrt.dynamo.lowering import (
|
@@ -59,17 +60,17 @@ def aot_torch_tensorrt_aten_backend(
|
59 | 60 | _pretraced_backend, settings=settings, engine_cache=engine_cache
|
60 | 61 | )
|
61 | 62 | settings_aot_autograd = {}
|
62 |
| - settings_aot_autograd["decompostions"] = get_decompositions( |
| 63 | + settings_aot_autograd["decompositions"] = get_decompositions( |
63 | 64 | settings.enable_experimental_decompositions
|
64 | 65 | )
|
65 |
| - # This is added since detach lowering leads to alias nodes |
66 |
| - # Error - View operation returned a tensor that is the same as the input base tensor |
67 |
| - # torch nop_decompositions in torch/_decomp/decompositions.py |
68 |
| - if aten.detach in settings_aot_autograd["decompositions"]: |
69 |
| - del settings_aot_autograd["decompositions"][aten.detach] |
| 66 | + # transpose key deleted since not desirable to lower it to permute |
| 67 | + for key in settings_aot_autograd["decompositions"]: |
| 68 | + if "transpose" in key._name: |
| 69 | + to_delete = key |
| 70 | + del settings_aot_autograd["decompositions"][to_delete] |
70 | 71 | return aot_autograd(
|
71 | 72 | fw_compiler=_pretraced_backend_autograd,
|
72 |
| - decompositions=get_decompositions(settings.enable_experimental_decompositions), |
| 73 | + decompositions=settings_aot_autograd["decompositions"], |
73 | 74 | )(gm, sample_inputs)
|
74 | 75 |
|
75 | 76 |
|
|
0 commit comments