Nccl ops correction changes #3387

Merged 14 commits on Mar 25, 2025
538 changes: 0 additions & 538 deletions examples/distributed_inference/llama3_model.py

This file was deleted.

70 changes: 0 additions & 70 deletions examples/distributed_inference/tensor_parallel_llama3.py

This file was deleted.

42 changes: 23 additions & 19 deletions examples/distributed_inference/tensor_parallel_simple_example.py
@@ -2,6 +2,7 @@

import tensorrt as trt
import torch
import torch.distributed as dist
import torch.nn as nn
import torch_tensorrt
from tensor_parallel_initialize_dist import initialize_distributed_env
@@ -15,7 +16,6 @@
device_mesh, _world_size, _rank, logger = initialize_distributed_env(
"./tensor_parallel_simple_example"
)
import tensorrt_llm

"""
This example copies some code from https://github.com/pytorch/examples/blob/main/distributed/tensor_parallelism/tensor_parallel_example.py
@@ -65,7 +65,6 @@ def forward(self, x):
inp = torch.rand(20, 10, device="cuda")
python_result = tp_model(inp)


backend = "torch_tensorrt"
tp_model = torch.compile(
tp_model,
@@ -75,23 +74,28 @@ def forward(self, x):
"enabled_precisions": {torch.float32, torch.float16},
"use_python_runtime": True,
"min_block_size": 1,
"use_aot_joint_export": False,
"use_distributed_mode_trace": True,
},
dynamic=False,
dynamic=None,
)

for i in range(10):
# For TP, input needs to be same across all TP ranks.
# Setting the random seed is to mimic the behavior of dataloader.
torch.manual_seed(i)
inp = torch.rand(20, 10, device="cuda")
start = time.time()
output = tp_model(inp)
end = time.time()
if i == 0:
logger.info(f"Compilation time is {end-start}")
assert (
python_result - output
).std() < 0.01, "Compilation result is not correct."
elif _rank == 0:
logger.info(f"Inference time is {end-start}")
try:
for i in range(10):
# For TP, input needs to be same across all TP ranks.
# Setting the random seed is to mimic the behavior of dataloader.
torch.manual_seed(i)
inp = torch.rand(20, 10, device="cuda")
start = time.time()
output = tp_model(inp)
end = time.time()
if i == 0:
logger.info(f"Compilation time is {end-start}")
assert (
python_result - output
).std() < 0.01, "Compilation result is not correct."
elif _rank == 0:
logger.info(f"Inference time is {end-start}")
finally:
# This cleans up the distributed process group
if dist.is_initialized():
dist.destroy_process_group()
2 changes: 1 addition & 1 deletion py/torch_tensorrt/dynamo/_defaults.py
@@ -46,9 +46,9 @@
IMMUTABLE_WEIGHTS = True
ENABLE_WEIGHT_STREAMING = False
ENABLE_CROSS_COMPILE_FOR_WINDOWS = False
USE_AOT_JOINT_EXPORT = True
TILING_OPTIMIZATION_LEVEL = "none"
L2_LIMIT_FOR_TILING = -1
USE_DISTRIBUTED_MODE_TRACE = False


def default_device() -> Device:
6 changes: 3 additions & 3 deletions py/torch_tensorrt/dynamo/_settings.py
@@ -35,7 +35,7 @@
TILING_OPTIMIZATION_LEVEL,
TIMING_CACHE_PATH,
TRUNCATE_DOUBLE,
USE_AOT_JOINT_EXPORT,
USE_DISTRIBUTED_MODE_TRACE,
USE_EXPLICIT_TYPING,
USE_FAST_PARTITIONER,
USE_FP32_ACC,
@@ -94,9 +94,9 @@ class CompilationSettings:
enable_weight_streaming (bool): Enable weight streaming.
enable_cross_compile_for_windows (bool): By default this is False means TensorRT engines can only be executed on the same platform where they were built.
True will enable cross-platform compatibility which allows the engine to be built on Linux and run on Windows
use_aot_joint_export (bool): Use aot_export_joint_simple, else wrap backend with AOT_autograd, required for distributed tensors
tiling_optimization_level (str): The optimization level of tiling strategies. A higher level allows TensorRT to spend more time searching for better tiling strategy. We currently support ["none", "fast", "moderate", "full"].
l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit).
use_distributed_mode_trace (bool): Using aot_autograd to trace the graph. This is enabled when DTensors or distributed tensors are present in distributed model
"""

enabled_precisions: Set[dtype] = field(default_factory=lambda: ENABLED_PRECISIONS)
@@ -137,9 +137,9 @@ class CompilationSettings:
immutable_weights: bool = IMMUTABLE_WEIGHTS
enable_weight_streaming: bool = ENABLE_WEIGHT_STREAMING
enable_cross_compile_for_windows: bool = ENABLE_CROSS_COMPILE_FOR_WINDOWS
use_aot_joint_export: bool = USE_AOT_JOINT_EXPORT
tiling_optimization_level: str = TILING_OPTIMIZATION_LEVEL
l2_limit_for_tiling: int = L2_LIMIT_FOR_TILING
use_distributed_mode_trace: bool = USE_DISTRIBUTED_MODE_TRACE


_SETTINGS_TO_BE_ENGINE_INVARIANT = (
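A minimal usage sketch of the new setting, assuming a plain single-GPU model for brevity; it mirrors the options dict used in the updated tensor_parallel_simple_example.py above (in practice the flag is meant to be combined with a tensor-parallel setup).

import torch
import torch.nn as nn
import torch_tensorrt  # noqa: F401  # registers the "torch_tensorrt" backend

model = nn.Linear(10, 5).cuda()
compiled = torch.compile(
    model,
    backend="torch_tensorrt",
    options={
        "enabled_precisions": {torch.float32, torch.float16},
        "use_python_runtime": True,
        "min_block_size": 1,
        # New flag replacing the removed use_aot_joint_export setting.
        "use_distributed_mode_trace": True,
    },
    dynamic=None,
)
out = compiled(torch.rand(20, 10, device="cuda"))  # compilation happens on first call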
72 changes: 35 additions & 37 deletions py/torch_tensorrt/dynamo/backend/backends.py
@@ -10,11 +10,11 @@
from torch._dynamo.backends.common import aot_autograd
from torch._dynamo.utils import detect_fake_mode
from torch._functorch.aot_autograd import aot_export_joint_simple
from torch.distributed.tensor import DTensor
from torch_tensorrt.dynamo import CompilationSettings
from torch_tensorrt.dynamo._compiler import compile_module
from torch_tensorrt.dynamo.lowering import (
get_decompositions,
modify_reshape_complex_nodes,
post_lowering,
remove_detach,
remove_sym_nodes,
@@ -52,25 +52,39 @@ def aot_torch_tensorrt_aten_backend(
gm: torch.fx.GraphModule, sample_inputs: Sequence[Any], **kwargs: Any
) -> torch.nn.Module:
settings, engine_cache = parse_dynamo_kwargs(kwargs)
if settings.use_aot_joint_export:
return _pretraced_backend(gm, sample_inputs, settings, engine_cache)
logger.debug("Wrapping the backend with aot_autograd\n")
_pretraced_backend_autograd = functools.partial(
_pretraced_backend, settings=settings, engine_cache=engine_cache
)
settings_aot_autograd = {}
settings_aot_autograd["decompostions"] = get_decompositions(
settings.enable_experimental_decompositions
)
# This is added since detach lowering leads to alias nodes
# Error - View operation returned a tensor that is the same as the input base tensor
# torch nop_decompositions in torch/_decomp/decompositions.py
if aten.detach in settings_aot_autograd["decompositions"]:
del settings_aot_autograd["decompositions"][aten.detach]
return aot_autograd(
fw_compiler=_pretraced_backend_autograd,
decompositions=get_decompositions(settings.enable_experimental_decompositions),
)(gm, sample_inputs)

if settings.use_distributed_mode_trace:
logger.debug(
"Wrapping the backend with aot_autograd for Distributed examples\n"
)
Comment on lines +58 to +59

Collaborator: Consider changing the message to "Using aot_autograd to trace the graph. Enable this only if the model includes distributed operations".

Collaborator: Not sure about "only". I would rather:

  1. have a warning for when users should use this trace mode but aren't, and
  2. have an info-level message when this mode is being used, "Using AOTAutograd tracer for model lowering" or something to that extent.
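
A small sketch of the logging behavior suggested in this thread; the helper name and exact message wording are assumptions, not code from this PR. It warns when distributed tensors are present but the trace mode is off, and logs at info level when the AOTAutograd tracer is actually used.

import logging
from typing import Any, Sequence

from torch.distributed.tensor import DTensor

logger = logging.getLogger(__name__)


def log_trace_mode(sample_inputs: Sequence[Any], use_distributed_mode_trace: bool) -> None:
    # Detect DTensors in the sample inputs, as the backend already does below.
    has_dtensor = any(isinstance(t, DTensor) for t in sample_inputs)
    if use_distributed_mode_trace:
        logger.info("Using AOTAutograd tracer for model lowering")
    elif has_dtensor:
        logger.warning(
            "Distributed tensors detected in the inputs; consider setting "
            "use_distributed_mode_trace=True"
        )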

_pretraced_backend_autograd = functools.partial(
_pretraced_backend, settings=settings, engine_cache=engine_cache
)
settings_aot_autograd = {}
settings_aot_autograd["decompositions"] = get_decompositions(
settings.enable_experimental_decompositions
)
# This is added since detach lowering leads to alias nodes
# Error - View operation returned a tensor that is the same as the input base tensor
# torch nop_decompositions in torch/_decomp/decompositions.py
# transpose key deleted since not desirable to lower it to permute
to_delete = {
key
for key in settings_aot_autograd["decompositions"]
if "detach" in key._name
}
for key in to_delete:
del settings_aot_autograd["decompositions"][key]

return aot_autograd(
fw_compiler=_pretraced_backend_autograd,
decompositions=settings_aot_autograd["decompositions"],
)(gm, sample_inputs)
if any(isinstance(tensor, DTensor) for tensor in sample_inputs):
logger.warning(
"It is recommended to run the model with use_distributed_mode_trace = True since there are distributed tensors in the input which is not supported in aot_export_joint_simple"
)
return _pretraced_backend(gm, sample_inputs, settings, engine_cache)


def _pretraced_backend(
@@ -110,18 +124,8 @@ def _pretraced_backend(
# Remove detach nodes
remove_detach(gm, settings)

complexInputIndices = []
for i, torch_input in enumerate(torch_inputs):
if torch_inputs[i].dtype == torch.complex64:
complexInputIndices.append(i)
torch_input_real = torch_inputs[i].real
torch_input_imaginary = torch_inputs[i].imag
torch_inputs[i] = torch.stack(
(torch_input_real, torch_input_imaginary), dim=-1
)

# Invoke AOTAutograd to translate operators to aten
if settings.use_aot_joint_export:
if not settings.use_distributed_mode_trace:
Collaborator: Do we need this check now, since we return from the aot_autograd block before?

Collaborator (Author): Yes, we still need it: when wrapping with aot_autograd, we still pass _pretraced_backend to the fw_compiler argument, e.g. aot_print_fn = aot_function(fn, fw_compiler=compiler_fn, bw_compiler=compiler_fn), but we don't want it to do the aot_joint_export tracing there.
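
A minimal sketch of the point above, with hypothetical helper names: when use_distributed_mode_trace is enabled, aot_autograd traces the graph itself and then calls the function passed as fw_compiler, so that function must not call aot_export_joint_simple a second time.

from typing import Any, Sequence

import torch
from torch._dynamo.backends.common import aot_autograd
from torch._functorch.aot_autograd import aot_export_joint_simple


def pretraced_backend_sketch(
    gm: torch.fx.GraphModule,
    sample_inputs: Sequence[Any],
    use_distributed_mode_trace: bool = False,
) -> torch.fx.GraphModule:
    if not use_distributed_mode_trace:
        # Only trace here when aot_autograd has not already done it for us.
        gm = aot_export_joint_simple(gm, sample_inputs, trace_joint=False)
    # ... conversion to a TensorRT engine would happen here ...
    return gm


# Distributed path: aot_autograd traces first, then hands the traced graph to
# fw_compiler, which is why the backend above must skip re-tracing.
distributed_backend = aot_autograd(
    fw_compiler=lambda gm, inputs: pretraced_backend_sketch(
        gm, inputs, use_distributed_mode_trace=True
    )
)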

gm = aot_export_joint_simple(
gm,
sample_inputs,
@@ -137,12 +141,6 @@

logger.debug("Lowered Input graph:\n " + str(gm.graph))

if complexInputIndices:
modify_reshape_complex_nodes(gm, complexInputIndices)
logger.debug(
"Input graph after modifying complex nodes:\n " + str(gm.graph)
)

torchtrt_inputs = prepare_inputs(
torch_inputs, disable_memory_format_check=True
)
8 changes: 3 additions & 5 deletions py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py
@@ -3,6 +3,7 @@
import logging
from typing import Dict, Sequence, Tuple, Union

import tensorrt as trt
from torch.fx.node import Argument, Target
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion import impl
@@ -16,8 +17,6 @@
tensorrt_fused_nccl_reduce_scatter_op,
)

import tensorrt as trt

_LOGGER: logging.Logger = logging.getLogger(__name__)

if load_tensorrt_llm():
Expand All @@ -30,7 +29,7 @@ def fused_nccl_gather(
kwargs: Dict[str, Argument],
name: str,
) -> Union[trt.ITensor, Sequence[trt.ITensor]]:
return impl.distributed.nccl_gather(
return impl.nccl_ops.nccl_gather(
ctx,
target,
SourceIR.ATEN,
@@ -46,15 +45,14 @@ def fused_nccl_reduce_scatter(
kwargs: Dict[str, Argument],
name: str,
) -> Union[trt.ITensor, Sequence[trt.ITensor]]:
return impl.distributed.nccl_reduce_scatter(
return impl.nccl_ops.nccl_reduce_scatter(
ctx,
target,
SourceIR.ATEN,
name,
[args[0]],
)

breakpoint()
else:
_LOGGER.debug(
"Did not load torch.distributed converters since TensorRT-LLM is not available"
5 changes: 2 additions & 3 deletions py/torch_tensorrt/dynamo/conversion/impl/nccl_ops.py
@@ -3,12 +3,11 @@
from typing import Optional, Tuple, Union

import numpy as np
import tensorrt as trt
from torch.fx.node import Argument, Target
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.fx.converters.converter_utils import SourceIR, set_layer_name

import tensorrt as trt


# class for AllReduce
class AllReduceStrategy(IntEnum):
@@ -94,7 +93,7 @@ def nccl_reduce_scatter(
"group", np.array(group, dtype=np.int32), trt.PluginFieldType.INT32
)

p_dtype = trt.float16
p_dtype = trt.float32
pf_dtype = trt.PluginField(
"type_id", np.array([int(p_dtype)], np.int32), trt.PluginFieldType.INT32
)
12 changes: 6 additions & 6 deletions py/torch_tensorrt/dynamo/lowering/passes/fuse_distributed_ops.py
@@ -49,7 +49,6 @@ def fuse_distributed_ops(
== torch.ops._c10d_functional.wait_tensor.default
):
wait_tensor_node = list(node.users)[0]
fused_op = None
if node.target == torch.ops._c10d_functional.all_gather_into_tensor.default:
with gm.graph.inserting_after(wait_tensor_node):
fused_node = gm.graph.create_node(
@@ -58,11 +57,12 @@
args=(node.args[0], node.args[1], node.args[2]),
)
else:
fused_node = gm.graph.create_node(
op="call_function",
target=tensorrt_fused_nccl_reduce_scatter_op, # Define your custom fused function
args=(node.args[0], node.args[1], node.args[2], node.args[3]),
)
with gm.graph.inserting_after(wait_tensor_node):
fused_node = gm.graph.create_node(
op="call_function",
target=tensorrt_fused_nccl_reduce_scatter_op, # Define your custom fused function
args=(node.args[0], node.args[1], node.args[2], node.args[3]),
)

wait_tensor_node.replace_all_uses_with(fused_node)
fused_node.meta.update(node.meta)
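
A hedged sketch of the torch.fx rewrite that both branches of this pass now share: the fused collective node is created inside an inserting_after context anchored at the wait_tensor node (previously the reduce-scatter branch skipped that context), and the wait_tensor users are redirected to the fused op. The helper name and its generic signature are illustrative, not part of the PR.

from typing import Any, Tuple

import torch


def insert_fused_collective(
    gm: torch.fx.GraphModule,
    source_node: torch.fx.Node,
    wait_tensor_node: torch.fx.Node,
    fused_target: Any,
    args: Tuple[Any, ...],
) -> torch.fx.Node:
    # Create the fused node right after the wait_tensor node it replaces.
    with gm.graph.inserting_after(wait_tensor_node):
        fused_node = gm.graph.create_node(
            op="call_function",
            target=fused_target,
            args=args,
        )
    # Route every consumer of the wait_tensor result to the fused op and
    # carry over the source node's metadata.
    wait_tensor_node.replace_all_uses_with(fused_node)
    fused_node.meta.update(source_node.meta)
    return fused_node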