Skip to content

Commit b040d81

Browse files
apboseperi044
authored andcommitted
nccl ops correction
1 parent c7d610a commit b040d81

File tree

4 files changed

+24
-12
lines changed

4 files changed

+24
-12
lines changed

py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py

+3-5
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import logging
44
from typing import Dict, Sequence, Tuple, Union
55

6+
import tensorrt as trt
67
from torch.fx.node import Argument, Target
78
from torch_tensorrt.dynamo._SourceIR import SourceIR
89
from torch_tensorrt.dynamo.conversion import impl
@@ -16,8 +17,6 @@
1617
tensorrt_fused_nccl_reduce_scatter_op,
1718
)
1819

19-
import tensorrt as trt
20-
2120
_LOGGER: logging.Logger = logging.getLogger(__name__)
2221

2322
if load_tensorrt_llm():
@@ -30,7 +29,7 @@ def fused_nccl_gather(
3029
kwargs: Dict[str, Argument],
3130
name: str,
3231
) -> Union[trt.ITensor, Sequence[trt.ITensor]]:
33-
return impl.distributed.nccl_gather(
32+
return impl.nccl_ops.nccl_gather(
3433
ctx,
3534
target,
3635
SourceIR.ATEN,
@@ -46,15 +45,14 @@ def fused_nccl_reduce_scatter(
4645
kwargs: Dict[str, Argument],
4746
name: str,
4847
) -> Union[trt.ITensor, Sequence[trt.ITensor]]:
49-
return impl.distributed.nccl_reduce_scatter(
48+
return impl.nccl_ops.nccl_reduce_scatter(
5049
ctx,
5150
target,
5251
SourceIR.ATEN,
5352
name,
5453
[args[0]],
5554
)
5655

57-
breakpoint()
5856
else:
5957
_LOGGER.debug(
6058
"Did not load torch.distributed converters since TensorRT-LLM is not available"

py/torch_tensorrt/dynamo/lowering/passes/_replace_complex_placeholder_to_tuple.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,12 @@ def update_node_meta(node: torch.fx.Node, fake_mode: FakeTensorMode) -> None:
106106

107107
if op_target in shape_inference_funcs:
108108
new_shape = shape_inference_funcs[op_target](node)
109-
real_tensor = torch.empty(new_shape, dtype=node.meta["val"].dtype)
109+
new_node_dtype = None
110+
if node.meta["val"].dtype == torch.complex64:
111+
new_node_dtype = torch.float32
112+
else:
113+
new_node_dtype = torch.float64
114+
real_tensor = torch.empty(new_shape, dtype=new_node_dtype)
110115
node.meta["val"] = fake_mode.from_tensor(real_tensor)
111116
else:
112117
print("No shape for the inference function", {op_name})

py/torch_tensorrt/dynamo/lowering/passes/fuse_distributed_ops.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,6 @@ def fuse_distributed_ops(
4949
== torch.ops._c10d_functional.wait_tensor.default
5050
):
5151
wait_tensor_node = list(node.users)[0]
52-
fused_op = None
5352
if node.target == torch.ops._c10d_functional.all_gather_into_tensor.default:
5453
with gm.graph.inserting_after(wait_tensor_node):
5554
fused_node = gm.graph.create_node(
@@ -58,11 +57,12 @@ def fuse_distributed_ops(
5857
args=(node.args[0], node.args[1], node.args[2]),
5958
)
6059
else:
61-
fused_node = gm.graph.create_node(
62-
op="call_function",
63-
target=tensorrt_fused_nccl_reduce_scatter_op, # Define your custom fused function
64-
args=(node.args[0], node.args[1], node.args[2], node.args[3]),
65-
)
60+
with gm.graph.inserting_after(wait_tensor_node):
61+
fused_node = gm.graph.create_node(
62+
op="call_function",
63+
target=tensorrt_fused_nccl_reduce_scatter_op, # Define your custom fused function
64+
args=(node.args[0], node.args[1], node.args[2], node.args[3]),
65+
)
6666

6767
wait_tensor_node.replace_all_uses_with(fused_node)
6868
fused_node.meta.update(node.meta)

py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py

+9
Original file line numberDiff line numberDiff line change
@@ -364,6 +364,15 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
364364
(i.contiguous() if isinstance(i, torch.Tensor) else torch.tensor(i).cuda())
365365
for i in inputs
366366
]
367+
368+
for i, contiguous_input in enumerate(contiguous_inputs):
369+
if contiguous_input.dtype == torch.complex64:
370+
contiguous_input_real = contiguous_input.real
371+
contiguous_input_imag = contiguous_input.imag
372+
contiguous_inputs[i] = torch.stack(
373+
(contiguous_input_real, contiguous_input_imag), dim=-1
374+
)
375+
367376
with (
368377
torch.autograd.profiler.record_function("PythonTorchTensorRTModule:Forward")
369378
if self.profiling_enabled

0 commit comments

Comments
 (0)