
Commit 9cf0e21

adding the test script and correction to the backend
1 parent d2aaa44 commit 9cf0e21

6 files changed, +308 −27 lines

Diff for: py/torch_tensorrt/dynamo/backend/backends.py

+1 −2

```diff
@@ -69,13 +69,12 @@ def aot_torch_tensorrt_aten_backend(
         to_delete = {
             key
             for key in settings_aot_autograd["decompositions"]
-            if "transpose" in key._name
+            if "transpose" in key._name or "detach" in key._name
         }

         for key in to_delete:
             del settings_aot_autograd["decompositions"][key]

-        remove_detach(gm, settings)
         return aot_autograd(
             fw_compiler=_pretraced_backend_autograd,
             decompositions=settings_aot_autograd["decompositions"],
```
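In short, the backend now drops `detach` decompositions from the AOT autograd table (alongside `transpose`) instead of stripping detach nodes afterwards with the separate `remove_detach(gm, settings)` pass. Below is a minimal, self-contained sketch of that filtering pattern; the `OpKey` class and the sample table are hypothetical stand-ins for the real `OpOverload` keys in `settings_aot_autograd["decompositions"]`, and only the set-comprehension filter mirrors the diff.

```python
# Minimal sketch of the decomposition-filtering pattern (stand-in data, not the real table).
from dataclasses import dataclass


@dataclass(frozen=True)
class OpKey:
    _name: str  # mimics OpOverload._name, e.g. "aten::detach"


decompositions = {
    OpKey("aten::transpose.int"): "decomp_fn_1",
    OpKey("aten::detach"): "decomp_fn_2",
    OpKey("aten::relu"): "decomp_fn_3",
}

# Drop transpose and detach decompositions so those ops reach the backend intact.
to_delete = {
    key
    for key in decompositions
    if "transpose" in key._name or "detach" in key._name
}
for key in to_delete:
    del decompositions[key]

assert [k._name for k in decompositions] == ["aten::relu"]
```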

Diff for: tests/py/dynamo/conversion/harness.py

−13

```diff
@@ -351,7 +351,6 @@ def generate_graph(
         enable_passes: bool,
         propagate_shapes: bool = False,
         settings: CompilationSettings = CompilationSettings(),
-        fuse_distributed_ops: bool = False,
         torch_export_dynamic_shapes: Optional[Any] = None,
     ):
         mod = mod.eval()
@@ -367,16 +366,6 @@ def generate_graph(
             tuple(torch_export_inputs),
             dynamic_shapes=torch_export_dynamic_shapes,
         )
-        if fuse_distributed_ops:
-            exported_program = exported_program.run_decompositions(
-                get_decompositions(False)
-            )
-            from torch_tensorrt.dynamo.lowering.passes.fuse_distributed_ops import (
-                fuse_distributed_ops,
-            )
-
-            gm = exported_program.graph_module
-            gm = fuse_distributed_ops(gm, settings)
         if enable_passes:
             exported_program = pre_export_lowering(exported_program, settings)
             exported_program = exported_program.run_decompositions(
@@ -415,7 +404,6 @@ def run_test(
         propagate_shapes=False,
         int32_reqd=False,
         immutable_weights=True,
-        fuse_distributed_ops=False,
     ):
         # TODO: lan to remove this and set use_dynamo_traccer to True by default
         # once all the converter test files are moved to use_dynamo_tracer
@@ -436,7 +424,6 @@ def run_test(
             enable_passes=enable_passes,
             propagate_shapes=propagate_shapes,
             settings=compilation_settings,
-            fuse_distributed_ops=fuse_distributed_ops,
         )

         num_inputs = len(inputs)
```

Diff for: tests/py/dynamo/distributed/distributed_utils.py

+74

```diff
@@ -0,0 +1,74 @@
+import logging
+import os
+
+import numpy as np
+import tensorrt as trt
+import torch
+import torch.distributed as dist
+from torch.distributed._tensor.device_mesh import init_device_mesh
+
+
+def set_environment_variables_pytest():
+    os.environ["WORLD_SIZE"] = str(1)
+    os.environ["RANK"] = str(0)
+    os.environ["MASTER_ADDR"] = "127.0.0.1"
+    os.environ["MASTER_PORT"] = str(29500)
+    os.environ["USE_TRTLLM_PLUGINS"] = "1"
+
+
+def find_repo_root(max_depth=10):
+    dir_path = os.path.dirname(os.path.realpath(__file__))
+    for i in range(max_depth):
+        files = os.listdir(dir_path)
+        if "MODULE.bazel" in files:
+            return dir_path
+        else:
+            dir_path = os.path.dirname(dir_path)
+
+    raise RuntimeError("Could not find repo root")
+
+
+def initialize_logger(rank, logger_file_name):
+    logger = logging.getLogger()
+    logger.setLevel(logging.INFO)
+    fh = logging.FileHandler(logger_file_name + f"_{rank}.log", mode="w")
+    fh.setLevel(logging.INFO)
+    logger.addHandler(fh)
+    return logger
+
+
+# This is required for env initialization since we use mpirun
+def initialize_distributed_env(logger_file_name, rank=0, world_size=1, port=29500):
+    local_rank = int(
+        os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", rank % torch.cuda.device_count())
+    )
+    world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", world_size))
+
+    # Set up environment variable to run with mpirun
+    os.environ["RANK"] = str(local_rank)
+    os.environ["WORLD_SIZE"] = str(world_size)
+    os.environ["MASTER_ADDR"] = "127.0.0.1"
+    os.environ["MASTER_PORT"] = str(port)
+    os.environ["TRTLLM_PLUGINS_PATH"] = (
+        find_repo_root() + "/lib/libnvinfer_plugin_tensorrt_llm.so"
+    )
+
+    # Necessary to assign a device to each rank.
+    torch.cuda.set_device(local_rank)
+
+    # We use nccl backend
+    dist.init_process_group("nccl")
+
+    # set a manual seed for reproducibility
+    torch.manual_seed(1111)
+
+    device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(world_size,))
+    rank = device_mesh.get_rank()
+    assert rank == local_rank
+    logger = initialize_logger(rank, logger_file_name)
+    device_id = (
+        rank % torch.cuda.device_count()
+    )  # Ensure each rank gets a unique device
+    torch.cuda.set_device(device_id)
+
+    return device_mesh, world_size, rank, logger
```
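For context, a hedged usage sketch of the new helper module follows. The script name, the log-file path, and the mpirun invocation in the comment are illustrative assumptions; only `initialize_distributed_env` and its return values come from the file above.

```python
# Illustrative only: how a test script is expected to consume distributed_utils.py.
# Launch under MPI so the OMPI_COMM_WORLD_* variables the helper reads are set,
# e.g. (assumption): mpirun -n 1 python my_tp_script.py
from distributed_utils import initialize_distributed_env

device_mesh, world_size, rank, logger = initialize_distributed_env(
    "./my_tp_script"  # per-rank log file becomes ./my_tp_script_<rank>.log
)
logger.info(f"rank {rank}/{world_size} initialized on mesh {device_mesh}")
```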
Diff for: (new example script; file path not captured)

+92

```diff
@@ -0,0 +1,92 @@
+import time
+
+import tensorrt as trt
+import torch
+import torch.nn as nn
+import torch_tensorrt
+from distributed_utils import initialize_distributed_env
+from torch.distributed._tensor import Shard
+from torch.distributed.tensor.parallel import (
+    ColwiseParallel,
+    RowwiseParallel,
+    parallelize_module,
+)
+
+device_mesh, _world_size, _rank, logger = initialize_distributed_env(
+    "./tensor_parallel_simple_example"
+)
+
+"""
+This example copies some code from https://github.com/pytorch/examples/blob/main/distributed/tensor_parallelism/tensor_parallel_example.py
+"""
+
+
+class ToyModel(nn.Module):
+    """MLP based model"""
+
+    def __init__(self):
+        super(ToyModel, self).__init__()
+        self.in_proj = nn.Linear(10, 3200)
+        self.relu = nn.ReLU()
+        self.out_proj = nn.Linear(3200, 1600)
+        self.in_proj2 = nn.Linear(1600, 500)
+        self.out_proj2 = nn.Linear(500, 100)
+
+    def forward(self, x):
+        x = self.out_proj(self.relu(self.in_proj(x)))
+        x = self.relu(x)
+        x = self.out_proj2(self.relu(self.in_proj2(x)))
+        return x
+
+
+logger.info(f"Starting PyTorch TP example on rank {_rank}.")
+
+# create model and move it to GPU - init_device_mesh has already mapped GPU ids.
+tp_model = ToyModel().to("cuda")
+
+
+# Custom parallelization plan for the model
+tp_model = parallelize_module(
+    module=tp_model,
+    device_mesh=device_mesh,
+    parallelize_plan={
+        "in_proj": ColwiseParallel(input_layouts=Shard(0)),
+        "out_proj": RowwiseParallel(output_layouts=Shard(0)),
+        "in_proj2": ColwiseParallel(input_layouts=Shard(0)),
+        "out_proj2": RowwiseParallel(output_layouts=Shard(0)),
+    },
+)
+torch.manual_seed(0)
+inp = torch.rand(20, 10, device="cuda")
+python_result = tp_model(inp)
+
+
+backend = "torch_tensorrt"
+tp_model = torch.compile(
+    tp_model,
+    backend=backend,
+    options={
+        "truncate_long_and_double": True,
+        "enabled_precisions": {torch.float32, torch.float16},
+        "use_python_runtime": True,
+        "min_block_size": 1,
+        "use_aot_joint_export": False,
+    },
+    dynamic=False,
+)
+
+for i in range(10):
+    # For TP, input needs to be same across all TP ranks.
+    # Setting the random seed is to mimic the behavior of dataloader.
+    torch.manual_seed(i)
+    inp = torch.rand(20, 10, device="cuda")
+    start = time.time()
+    output = tp_model(inp)
+    end = time.time()
+    if i == 0:
+        logger.info(f"Compilation time is {end-start}")
+        assert (
+            python_result - output
+        ).std() < 0.01, "Compilation result is not correct."
+    elif _rank == 0:
+        logger.info(f"Inference time is {end-start}")
```
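The parallelization plan in this script pairs `ColwiseParallel` and `RowwiseParallel` layers so the sharded computation reproduces the dense one. Below is a rough, eager-mode illustration of the underlying weight-sharding math (ignoring the `Shard(0)` batch layouts, biases, and the all-reduce that real tensor parallelism performs); the shapes and the two-way split are arbitrary choices for the example.

```python
# CPU-only sketch: colwise-sharding the first weight and rowwise-sharding the
# second, then summing the partial products, matches the unsharded result.
import torch

torch.manual_seed(0)
x = torch.rand(4, 10)
w1 = torch.rand(16, 10)  # first Linear (10 -> 16), split by output features (colwise)
w2 = torch.rand(8, 16)   # second Linear (16 -> 8), split by input features (rowwise)

dense = torch.relu(x @ w1.T) @ w2.T

w1_a, w1_b = w1.chunk(2, dim=0)  # each "rank" keeps half of the output features
w2_a, w2_b = w2.chunk(2, dim=1)  # and the matching half of the next layer's inputs
sharded = torch.relu(x @ w1_a.T) @ w2_a.T + torch.relu(x @ w1_b.T) @ w2_b.T

assert torch.allclose(dense, sharded, atol=1e-4)
```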

Diff for: tests/py/dynamo/distributed/test_nccl_ops.py

+4 −12

```diff
@@ -3,19 +3,11 @@
 import torch
 import torch.distributed as dist
 import torch.nn as nn
+from distributed_utils import set_environment_variables_pytest
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests

-
-def set_environment_variables():
-    os.environ["WORLD_SIZE"] = str(1)
-    os.environ["RANK"] = str(0)
-    os.environ["MASTER_ADDR"] = "127.0.0.1"
-    os.environ["MASTER_PORT"] = str(29500)
-    os.environ["USE_TRTLLM_PLUGINS"] = "1"
-
-
-set_environment_variables()
+set_environment_variables_pytest()
 dist.init_process_group(backend="nccl", init_method="env://")
 group = dist.new_group(ranks=[0])
 group_name = group.group_name
@@ -47,7 +39,7 @@ def forward(self, x):
             DistributedGatherModel(linear_layer_dim).cuda(),
             inputs,
             use_dynamo_tracer=True,
-            fuse_distributed_ops=True,
+            enable_passes=True,
         )

     @parameterized.expand([(8)])
@@ -76,7 +68,7 @@ def forward(self, x):
             DistributedReduceScatterModel(linear_layer_dim).cuda(),
             inputs,
             use_dynamo_tracer=True,
-            fuse_distributed_ops=True,
+            enable_passes=True,
         )

```
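The `DistributedGatherModel` and `DistributedReduceScatterModel` classes referenced in these hunks are defined earlier in `test_nccl_ops.py` and are untouched by this commit. Purely as a hypothetical orientation aid, a gather-style module might look roughly like the sketch below, written against the plain `torch.distributed` API rather than whatever the test file actually uses.

```python
# Hypothetical sketch only; not the model from test_nccl_ops.py.
import torch
import torch.distributed as dist
import torch.nn as nn


class GatherSketch(nn.Module):
    """Linear layer followed by an all-gather across the process group."""

    def __init__(self, input_dim: int, world_size: int):
        super().__init__()
        self.fc = nn.Linear(input_dim, input_dim)
        self.world_size = world_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = self.fc(x)
        # Output buffer holding every rank's slice along dim 0.
        gathered = torch.empty(
            (self.world_size * y.shape[0], *y.shape[1:]),
            device=y.device,
            dtype=y.dtype,
        )
        dist.all_gather_into_tensor(gathered, y.contiguous())
        return gathered
```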
