@@ -66,10 +66,13 @@ TORCH_LIBRARY_IMPL(c10d, XPU, m) {
   m.impl("broadcast_", broadcast_xpu_);
 }
 
+#if TORCH_VERSION_MAJOR > 1 && TORCH_VERSION_MINOR >= 1
+// PyTorch 2.1 allreduce supports sparse tensors
 std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<C10D_Work>> allreduce_xpu_(
     at::TensorList tensors,
     const c10::intrusive_ptr<ProcessGroup>& process_group,
     const c10::intrusive_ptr<ReduceOp>& reduce_op,
+    const c10::optional<at::Tensor>& sparse_indices,
     int64_t timeout) {
   auto tensor_vec = tensors.vec();
   auto work =
@@ -85,6 +88,28 @@ std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<C10D_Work>> allreduce_xpu
   return std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<C10D_Work>>(
       std::move(tensor_vec), work);
 }
+#else
+// TODO: Remove after updating to PyTorch 2.1
+std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<C10D_Work>> allreduce_xpu_(
+    at::TensorList tensors,
+    const c10::intrusive_ptr<ProcessGroup>& process_group,
+    const c10::intrusive_ptr<ReduceOp>& reduce_op,
+    int64_t timeout) {
+  auto tensor_vec = tensors.vec();
+  auto work =
+      process_group->getBackend(c10::DeviceType::XPU)
+          ->allreduce(
+              tensor_vec,
+              c10d::AllreduceOptions{
+                  *reduce_op.get(), std::chrono::milliseconds(timeout)});
+
+  // Return input tensors as output tensors to make inplace allreduce look like
+  // a functional API, so that make_fx can correctly build the dependencies in
+  // the graph later.
+  return std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<C10D_Work>>(
+      std::move(tensor_vec), work);
+}
+#endif
 
 TORCH_LIBRARY_IMPL(c10d, XPU, m) {
   m.impl("allreduce_", allreduce_xpu_);
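
The comment in the fallback branch above (returning the input tensors as outputs so that make_fx can build dependencies) is the usual trick for giving an in-place collective a functional-looking signature. Below is a toy sketch of the idea only, using a stand-in Tensor type and a made-up allreduce_like_ function rather than the real at::Tensor/c10d types: downstream code consumes the returned values, so a tracer sees a producer/consumer edge instead of a hidden mutation.

// toy_functional_inplace.cpp -- illustrative only, not the c10d API.
#include <cassert>
#include <vector>

using Tensor = std::vector<float>;

// In-place sum-reduce across "ranks", written with a functional signature:
// it mutates every tensor in `ts` but also returns them as outputs.
std::vector<Tensor>& allreduce_like_(std::vector<Tensor>& ts) {
  for (size_t i = 0; i < ts[0].size(); ++i) {
    float sum = 0.f;
    for (auto& t : ts) sum += t[i];
    for (auto& t : ts) t[i] = sum;  // in-place write
  }
  return ts;  // returning the inputs makes the mutation visible as an output
}

int main() {
  std::vector<Tensor> ts{{1.f, 2.f}, {3.f, 4.f}};
  auto& out = allreduce_like_(ts);  // consume `out`, not `ts`, downstream
  assert(out[0][0] == 4.f && out[1][1] == 6.f);  // every replica holds the sum
  return 0;
}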