
Commit 3adafaf

restructuring the code to include option use_distributed_mode_trace
1 parent 4b79bfb commit 3adafaf

File tree

6 files changed: +38 -36 lines changed


examples/distributed_inference/tensor_parallel_llama3.py (+1 -1)

@@ -51,7 +51,7 @@
         "use_python_runtime": True,
         "workspace_size": 1 << 33,
         "debug": False,
-        "use_aot_joint_export": False,
+        "use_distributed_mode_trace": True,
     },
     dynamic=False,
 )

examples/distributed_inference/tensor_parallel_simple_example.py (+1 -1)

@@ -74,7 +74,7 @@ def forward(self, x):
         "enabled_precisions": {torch.float32, torch.float16},
         "use_python_runtime": True,
         "min_block_size": 1,
-        "use_aot_joint_export": False,
+        "use_distributed_mode_trace": True,
     },
     dynamic=False,
 )
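For context, a minimal single-process sketch of how these options flow through torch.compile's "torch_tensorrt" backend. The toy nn.Sequential model is illustrative only and stands in for the tensor-parallel models the two examples actually compile; it assumes a CUDA device and that importing torch_tensorrt registers the backend.

import torch
import torch_tensorrt  # noqa: F401  (importing registers the "torch_tensorrt" backend)

# Toy stand-in model; the real examples compile a tensor-parallelized model across ranks.
model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU()).cuda()

compiled = torch.compile(
    model,
    backend="torch_tensorrt",
    options={
        "enabled_precisions": {torch.float32, torch.float16},
        "use_python_runtime": True,
        "min_block_size": 1,
        "use_distributed_mode_trace": True,  # replaces "use_aot_joint_export": False
    },
    dynamic=False,
)

print(compiled(torch.randn(4, 16, device="cuda")).shape)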

py/torch_tensorrt/dynamo/_defaults.py (+1 -1)

@@ -46,7 +46,7 @@
 IMMUTABLE_WEIGHTS = True
 ENABLE_WEIGHT_STREAMING = False
 ENABLE_CROSS_COMPILE_FOR_WINDOWS = False
-USE_AOT_JOINT_EXPORT = True
+USE_DISTRIBUTED_MODE_TRACE = False


 def default_device() -> Device:

py/torch_tensorrt/dynamo/_settings.py (+3 -3)

@@ -33,7 +33,7 @@
     STRIP_ENGINE_WEIGHTS,
     TIMING_CACHE_PATH,
     TRUNCATE_DOUBLE,
-    USE_AOT_JOINT_EXPORT,
+    USE_DISTRIBUTED_MODE_TRACE,
     USE_EXPLICIT_TYPING,
     USE_FAST_PARTITIONER,
     USE_FP32_ACC,
@@ -92,7 +92,7 @@ class CompilationSettings:
         enable_weight_streaming (bool): Enable weight streaming.
         enable_cross_compile_for_windows (bool): By default this is False means TensorRT engines can only be executed on the same platform where they were built.
             True will enable cross-platform compatibility which allows the engine to be built on Linux and run on Windows
-        use_aot_joint_export (bool): Use aot_export_joint_simple, else wrap backend with AOT_autograd, required for distributed tensors
+        use_distributed_mode_trace (bool): If True, wrap the backend with aot_autograd (required for tracing distributed tensors); otherwise use aot_export_joint_simple
     """

     enabled_precisions: Set[dtype] = field(default_factory=lambda: ENABLED_PRECISIONS)
@@ -133,7 +133,7 @@ class CompilationSettings:
     immutable_weights: bool = IMMUTABLE_WEIGHTS
     enable_weight_streaming: bool = ENABLE_WEIGHT_STREAMING
     enable_cross_compile_for_windows: bool = ENABLE_CROSS_COMPILE_FOR_WINDOWS
-    use_aot_joint_export: bool = USE_AOT_JOINT_EXPORT
+    use_distributed_mode_trace: bool = USE_DISTRIBUTED_MODE_TRACE


 _SETTINGS_TO_BE_ENGINE_INVARIANT = (
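A quick sketch of the new field's default and how it can be overridden, assuming the in-tree import path shown in this diff stays unchanged:

from torch_tensorrt.dynamo._settings import CompilationSettings

# Defaults to USE_DISTRIBUTED_MODE_TRACE = False from _defaults.py
print(CompilationSettings().use_distributed_mode_trace)  # False

# Opt in explicitly, e.g. when compiling DTensor / tensor-parallel models
print(CompilationSettings(use_distributed_mode_trace=True).use_distributed_mode_trace)  # True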

py/torch_tensorrt/dynamo/backend/backends.py (+30 -28)

@@ -52,33 +52,35 @@ def aot_torch_tensorrt_aten_backend(
     gm: torch.fx.GraphModule, sample_inputs: Sequence[Any], **kwargs: Any
 ) -> torch.nn.Module:
     settings, engine_cache = parse_dynamo_kwargs(kwargs)
-    if settings.use_aot_joint_export:
-        return _pretraced_backend(gm, sample_inputs, settings, engine_cache)
-    logger.debug("Wrapping the backend with aot_autograd\n")
-    _pretraced_backend_autograd = functools.partial(
-        _pretraced_backend, settings=settings, engine_cache=engine_cache
-    )
-    settings_aot_autograd = {}
-    settings_aot_autograd["decompositions"] = get_decompositions(
-        settings.enable_experimental_decompositions
-    )
-    # This is added since detach lowering leads to alias nodes
-    # Error - View operation returned a tensor that is the same as the input base tensor
-    # torch nop_decompositions in torch/_decomp/decompositions.py
-    # transpose key deleted since not desirable to lower it to permute
-    to_delete = {
-        key
-        for key in settings_aot_autograd["decompositions"]
-        if "transpose" in key._name or "detach" in key._name
-    }
-
-    for key in to_delete:
-        del settings_aot_autograd["decompositions"][key]
-
-    return aot_autograd(
-        fw_compiler=_pretraced_backend_autograd,
-        decompositions=settings_aot_autograd["decompositions"],
-    )(gm, sample_inputs)
+
+    if settings.use_distributed_mode_trace:
+        logger.debug(
+            "Wrapping the backend with aot_autograd for Distributed examples\n"
+        )
+        _pretraced_backend_autograd = functools.partial(
+            _pretraced_backend, settings=settings, engine_cache=engine_cache
+        )
+        settings_aot_autograd = {}
+        settings_aot_autograd["decompositions"] = get_decompositions(
+            settings.enable_experimental_decompositions
+        )
+        # This is added since detach lowering leads to alias nodes
+        # Error - View operation returned a tensor that is the same as the input base tensor
+        # torch nop_decompositions in torch/_decomp/decompositions.py
+        # transpose key deleted since not desirable to lower it to permute
+        to_delete = {
+            key
+            for key in settings_aot_autograd["decompositions"]
+            if "transpose" in key._name or "detach" in key._name
+        }
+        for key in to_delete:
+            del settings_aot_autograd["decompositions"][key]
+
+        return aot_autograd(
+            fw_compiler=_pretraced_backend_autograd,
+            decompositions=settings_aot_autograd["decompositions"],
+        )(gm, sample_inputs)
+    return _pretraced_backend(gm, sample_inputs, settings, engine_cache)


 def _pretraced_backend(
@@ -129,7 +131,7 @@ def _pretraced_backend(
     )

     # Invoke AOTAutograd to translate operators to aten
-    if settings.use_aot_joint_export:
+    if not settings.use_distributed_mode_trace:
         gm = aot_export_joint_simple(
             gm,
             sample_inputs,
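For reference, a standalone sketch of the aot_autograd-wrapping pattern that aot_torch_tensorrt_aten_backend follows when use_distributed_mode_trace=True. The names toy_fw_compiler and toy_backend are illustrative and not part of this commit; the sketch assumes the documented torch custom-backend helpers (aot_autograd, make_boxed_func).

import torch
from torch._dynamo.backends.common import aot_autograd
from functorch.compile import make_boxed_func


def toy_fw_compiler(gm: torch.fx.GraphModule, sample_inputs):
    # A real backend would lower `gm` (e.g. to TensorRT) here; the sketch
    # returns the traced forward unchanged, boxed as AOTAutograd expects.
    return make_boxed_func(gm.forward)


def toy_backend(gm: torch.fx.GraphModule, sample_inputs):
    # Same shape as the use_distributed_mode_trace=True branch above:
    # aot_autograd(fw_compiler=...) builds a backend callable that traces the
    # graph through AOTAutograd before handing it to the forward compiler.
    return aot_autograd(fw_compiler=toy_fw_compiler)(gm, sample_inputs)


compiled = torch.compile(torch.nn.Linear(8, 8), backend=toy_backend)
print(compiled(torch.randn(2, 8)).shape)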

tests/py/dynamo/distributed/test_nccl_ops.py (+2 -2)

@@ -17,7 +17,7 @@


 class TestGatherNcclOpsConverter(DispatchTestCase):
-    @parameterized.expand([(8)])
+    @parameterized.expand([8])
     def test_nccl_ops(self, linear_layer_dim):
         class DistributedGatherModel(nn.Module):
             def __init__(self, input_dim):
@@ -42,7 +42,7 @@ def forward(self, x):
             enable_passes=True,
         )

-    @parameterized.expand([(8)])
+    @parameterized.expand([8])
     def test_nccl_ops_scatter(self, linear_layer_dim):

         class DistributedReduceScatterModel(nn.Module):
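The decorator change is purely cosmetic: (8) is a parenthesized int, not a one-element tuple (that would be (8,)), so [(8)] and [8] both pass the single value 8 to the test. A toy illustration, independent of this test file:

import unittest

from parameterized import parameterized


class ToyParamCase(unittest.TestCase):
    # Equivalent to @parameterized.expand([(8)]); bare values are treated as
    # single-argument cases, so linear_layer_dim receives 8 either way.
    @parameterized.expand([8])
    def test_dim(self, linear_layer_dim):
        self.assertEqual(linear_layer_dim, 8)


if __name__ == "__main__":
    unittest.main()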
