Commit d2aaa44

adding test cases
1 parent 66861f4 commit d2aaa44

3 files changed: +99 -3 lines changed

Diff for: py/torch_tensorrt/dynamo/conversion/impl/nccl_ops.py (+2 -3)

@@ -3,12 +3,11 @@
 from typing import Optional, Tuple, Union

 import numpy as np
+import tensorrt as trt
 from torch.fx.node import Argument, Target
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
 from torch_tensorrt.fx.converters.converter_utils import SourceIR, set_layer_name

-import tensorrt as trt
-

 # class for AllReduce
 class AllReduceStrategy(IntEnum):
@@ -94,7 +93,7 @@ def nccl_reduce_scatter(
         "group", np.array(group, dtype=np.int32), trt.PluginFieldType.INT32
     )

-    p_dtype = trt.float16
+    p_dtype = trt.float32
     pf_dtype = trt.PluginField(
         "type_id", np.array([int(p_dtype)], np.int32), trt.PluginFieldType.INT32
     )
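The functional change in this file is the type_id plugin field for the reduce-scatter plugin: the converter now advertises float32 instead of float16 (moving import tensorrt as trt up with the other imports is purely cosmetic). Below is a minimal sketch of how these two fields would typically be assembled into a trt.PluginFieldCollection before creating the plugin; the helper name is illustrative, not the actual nccl_ops.py code.

import numpy as np
import tensorrt as trt


def build_reduce_scatter_fields(group):
    # "group": the ranks participating in the collective, as an int32 array.
    p_group = trt.PluginField(
        "group", np.array(group, dtype=np.int32), trt.PluginFieldType.INT32
    )
    # "type_id": the tensor dtype the plugin operates in; after this commit
    # the converter requests float32 rather than float16.
    p_dtype = trt.float32
    pf_dtype = trt.PluginField(
        "type_id", np.array([int(p_dtype)], np.int32), trt.PluginFieldType.INT32
    )
    return trt.PluginFieldCollection([p_group, pf_dtype])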

Diff for: tests/py/dynamo/conversion/harness.py (+13)

@@ -351,6 +351,7 @@ def generate_graph(
         enable_passes: bool,
         propagate_shapes: bool = False,
         settings: CompilationSettings = CompilationSettings(),
+        fuse_distributed_ops: bool = False,
         torch_export_dynamic_shapes: Optional[Any] = None,
     ):
         mod = mod.eval()
@@ -366,6 +367,16 @@
             tuple(torch_export_inputs),
             dynamic_shapes=torch_export_dynamic_shapes,
         )
+        if fuse_distributed_ops:
+            exported_program = exported_program.run_decompositions(
+                get_decompositions(False)
+            )
+            from torch_tensorrt.dynamo.lowering.passes.fuse_distributed_ops import (
+                fuse_distributed_ops,
+            )
+
+            gm = exported_program.graph_module
+            gm = fuse_distributed_ops(gm, settings)
         if enable_passes:
             exported_program = pre_export_lowering(exported_program, settings)
             exported_program = exported_program.run_decompositions(
@@ -404,6 +415,7 @@ def run_test(
         propagate_shapes=False,
         int32_reqd=False,
         immutable_weights=True,
+        fuse_distributed_ops=False,
     ):
         # TODO: lan to remove this and set use_dynamo_traccer to True by default
         # once all the converter test files are moved to use_dynamo_tracer
@@ -424,6 +436,7 @@
             enable_passes=enable_passes,
             propagate_shapes=propagate_shapes,
             settings=compilation_settings,
+            fuse_distributed_ops=fuse_distributed_ops,
         )

         num_inputs = len(inputs)
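This change threads a new fuse_distributed_ops flag from run_test into generate_graph, which then runs the fuse_distributed_ops lowering pass over the exported graph before conversion. The real pass lives in torch_tensorrt.dynamo.lowering.passes.fuse_distributed_ops and is not shown in this commit; the sketch below only illustrates the general shape of such a pass, pairing a functional collective with its wait_tensor and rewriting the pair into a single call. The fused target fused_all_gather is a hypothetical stand-in, not the node the real pass emits.

import torch
from torch.fx import GraphModule


def fused_all_gather(tensor, world_size, group_name):
    # Hypothetical fused target: gather + wait in one call, standing in for
    # whatever single node the real fuse_distributed_ops pass produces.
    out = torch.ops._c10d_functional.all_gather_into_tensor(
        tensor, world_size, group_name
    )
    return torch.ops._c10d_functional.wait_tensor(out)


def fuse_all_gather_wait(gm: GraphModule) -> GraphModule:
    # Rewrite wait_tensor(all_gather_into_tensor(x, ...)) -> fused_all_gather(x, ...).
    for node in list(gm.graph.nodes):
        if node.op != "call_function":
            continue
        if node.target != torch.ops._c10d_functional.wait_tensor.default:
            continue
        producer = node.args[0]
        if (
            producer.op == "call_function"
            and producer.target
            == torch.ops._c10d_functional.all_gather_into_tensor.default
        ):
            with gm.graph.inserting_after(node):
                fused = gm.graph.call_function(fused_all_gather, args=producer.args)
            node.replace_all_uses_with(fused)
            gm.graph.erase_node(node)
            if not producer.users:
                gm.graph.erase_node(producer)
    gm.graph.lint()
    gm.recompile()
    return gm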

Diff for: tests/py/dynamo/distributed/test_nccl_ops.py (+84, new file)

@@ -0,0 +1,84 @@
+import os
+
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests
+
+
+def set_environment_variables():
+    os.environ["WORLD_SIZE"] = str(1)
+    os.environ["RANK"] = str(0)
+    os.environ["MASTER_ADDR"] = "127.0.0.1"
+    os.environ["MASTER_PORT"] = str(29500)
+    os.environ["USE_TRTLLM_PLUGINS"] = "1"
+
+
+set_environment_variables()
+dist.init_process_group(backend="nccl", init_method="env://")
+group = dist.new_group(ranks=[0])
+group_name = group.group_name
+world_size = 1
+
+from conversion.harness import DispatchTestCase
+
+
+class TestGatherNcclOpsConverter(DispatchTestCase):
+    @parameterized.expand([(8)])
+    def test_nccl_ops(self, linear_layer_dim):
+        class DistributedGatherModel(nn.Module):
+            def __init__(self, input_dim):
+                super().__init__()
+                self.fc = torch.nn.Linear(input_dim, input_dim)
+
+            def forward(self, x):
+                x = self.fc(x)
+                gathered_tensor = torch.ops._c10d_functional.all_gather_into_tensor(
+                    x, world_size, group_name
+                )
+                gathered_tensor = torch.ops._c10d_functional.wait_tensor(
+                    gathered_tensor
+                )
+                return gathered_tensor
+
+        inputs = [torch.randn(1, linear_layer_dim).to("cuda")]
+        self.run_test(
+            DistributedGatherModel(linear_layer_dim).cuda(),
+            inputs,
+            use_dynamo_tracer=True,
+            fuse_distributed_ops=True,
+        )
+
+    @parameterized.expand([(8)])
+    def test_nccl_ops_scatter(self, linear_layer_dim):
+
+        class DistributedReduceScatterModel(nn.Module):
+            def __init__(self, input_dim):
+                super().__init__()
+                self.fc = torch.nn.Linear(input_dim, input_dim)
+
+            def forward(self, x):
+                x = self.fc(x)
+                scatter_reduce_tensor = (
+                    torch.ops._c10d_functional.reduce_scatter_tensor(
+                        x, "sum", world_size, group_name
+                    )
+                )
+                scatter_reduce_tensor = torch.ops._c10d_functional.wait_tensor(
+                    scatter_reduce_tensor
+                )
+                return scatter_reduce_tensor
+
+        inputs = [torch.zeros(1, linear_layer_dim).to("cuda")]
+
+        self.run_test(
+            DistributedReduceScatterModel(linear_layer_dim).cuda(),
+            inputs,
+            use_dynamo_tracer=True,
+            fuse_distributed_ops=True,
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
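The new test file pins WORLD_SIZE to 1 and initializes a single-rank NCCL process group at import time, so the converter tests run in one process on one GPU. A quick eager sanity sketch (assuming the same environment-variable setup and process group as in the file above): with a single rank both collectives are effectively identities, which is what lets run_test compare the TensorRT output directly against the eager PyTorch output.

import torch

# Assumes set_environment_variables(), dist.init_process_group(...) and
# group_name from the test module above have already run on this process.
x = torch.randn(1, 8).cuda()

gathered = torch.ops._c10d_functional.all_gather_into_tensor(x, 1, group_name)
gathered = torch.ops._c10d_functional.wait_tensor(gathered)
assert torch.equal(gathered, x)  # gathering from a single rank returns the input

reduced = torch.ops._c10d_functional.reduce_scatter_tensor(x, "sum", 1, group_name)
reduced = torch.ops._c10d_functional.wait_tensor(reduced)
assert torch.equal(reduced, x)  # summing over a single rank returns the input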
