Skip to content

Commit 1fc4d21

Browse files
committed
Allreduce: add support for sparse tensors (#98)
* Fix issue of Mismatch in kernel C++ signatures with latest PyTorch
1 parent 376cc1d commit 1fc4d21

File tree

1 file changed

+25
-0
lines changed

1 file changed

+25
-0
lines changed

src/ProcessGroupCCL.cpp

+25
Original file line numberDiff line numberDiff line change
@@ -66,10 +66,13 @@ TORCH_LIBRARY_IMPL(c10d, XPU, m) {
6666
m.impl("broadcast_", broadcast_xpu_);
6767
}
6868

69+
#if TORCH_VERSION_MAJOR > 1 && TORCH_VERSION_MINOR >= 1
70+
// PyTorch 2.1 allreduce support sparse tensor
6971
std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<C10D_Work>> allreduce_xpu_(
7072
at::TensorList tensors,
7173
const c10::intrusive_ptr<ProcessGroup>& process_group,
7274
const c10::intrusive_ptr<ReduceOp>& reduce_op,
75+
const c10::optional<at::Tensor>& sparse_indices,
7376
int64_t timeout) {
7477
auto tensor_vec = tensors.vec();
7578
auto work =
@@ -85,6 +88,28 @@ std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<C10D_Work>> allreduce_xpu
8588
return std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<C10D_Work>>(
8689
std::move(tensor_vec), work);
8790
}
91+
#else
92+
// TODO: Remove after updating to PyTorch 2.1
93+
std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<C10D_Work>> allreduce_xpu_(
94+
at::TensorList tensors,
95+
const c10::intrusive_ptr<ProcessGroup>& process_group,
96+
const c10::intrusive_ptr<ReduceOp>& reduce_op,
97+
int64_t timeout) {
98+
auto tensor_vec = tensors.vec();
99+
auto work =
100+
process_group->getBackend(c10::DeviceType::XPU)
101+
->allreduce(
102+
tensor_vec,
103+
c10d::AllreduceOptions{
104+
*reduce_op.get(), std::chrono::milliseconds(timeout)});
105+
106+
// Return input tensors as output tensors to make inplace allreduce look like
107+
// a functional API, so that make_fx can correctly build the dependencies in
108+
// the graph later.
109+
return std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<C10D_Work>>(
110+
std::move(tensor_vec), work);
111+
}
112+
#endif
88113

89114
TORCH_LIBRARY_IMPL(c10d, XPU, m) {
90115
m.impl("allreduce_", allreduce_xpu_);

0 commit comments

Comments
 (0)