diff --git a/src/ProcessGroupCCL.cpp b/src/ProcessGroupCCL.cpp
index a6d44ff..93ce730 100644
--- a/src/ProcessGroupCCL.cpp
+++ b/src/ProcessGroupCCL.cpp
@@ -94,6 +94,7 @@ std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>> allreduce_xpu
     const c10::intrusive_ptr<ProcessGroup>& process_group,
     const c10::intrusive_ptr<ReduceOp>& reduce_op,
     const c10::optional<at::Tensor>& sparse_indices,
+    bool asyncOp,
     int64_t timeout) {
   auto tensor_vec = tensors.vec();
   auto work =
@@ -101,7 +102,7 @@ std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>> allreduce_xpu
           ->allreduce(
               tensor_vec,
               c10d::AllreduceOptions{
-                  *reduce_op.get(), std::chrono::milliseconds(timeout)});
+                  *reduce_op.get(), std::chrono::milliseconds(timeout), asyncOp});
 
   // Return input tensors as output tensors to make inplace allreduce look like
   // a functional API, so that make_fx can correctly build the dependencies in
@@ -140,11 +141,13 @@ c10::intrusive_ptr<Work> allreduce_coalesced_xpu_(
     at::TensorList tensors,
     const c10::intrusive_ptr<ProcessGroup>& process_group,
     const c10::intrusive_ptr<ReduceOp>& reduce_op,
+    bool asyncOp,
     int64_t timeout) {
   auto tensor_vec = tensors.vec();
   AllreduceCoalescedOptions opts = AllreduceCoalescedOptions{};
   opts.reduceOp = *reduce_op.get();
   opts.timeout = std::chrono::milliseconds(timeout);
+  opts.asyncOp = asyncOp;
 
   return process_group->getBackend(c10::DeviceType::XPU)
       ->allreduce_coalesced(tensor_vec, opts);
@@ -160,6 +163,7 @@ c10::intrusive_ptr<Work> reduce_xpu_(
     const c10::intrusive_ptr<ReduceOp>& reduce_op,
     int64_t root_rank,
     int64_t root_tensor,
+    bool asyncOp,
     int64_t timeout) {
   auto tensor_vec = tensors.vec();
   return process_group->getBackend(c10::DeviceType::XPU)
@@ -169,7 +173,7 @@ c10::intrusive_ptr<Work> reduce_xpu_(
               *reduce_op.get(),
               root_rank,
               root_tensor,
-              std::chrono::milliseconds(timeout)});
+              std::chrono::milliseconds(timeout), asyncOp});
 }
 
 TORCH_LIBRARY_IMPL(c10d, XPU, m) {
@@ -181,6 +185,7 @@ allgather_xpu_(
     const std::vector<std::vector<at::Tensor>>& output_tensors,
     at::TensorList input_tensors,
     const c10::intrusive_ptr<ProcessGroup>& process_group,
+    bool asyncOp,
     int64_t timeout) {
   auto input_tensors_vec = input_tensors.vec();
   auto work =
@@ -188,7 +193,7 @@ allgather_xpu_(
           ->allgather(
               const_cast<std::vector<std::vector<at::Tensor>>&>(output_tensors),
               input_tensors_vec,
-              AllgatherOptions{std::chrono::milliseconds(timeout)});
+              AllgatherOptions{std::chrono::milliseconds(timeout), asyncOp});
 
   // Copy output tensors (not storage) so that this can be used in a functional
   // manner
@@ -234,12 +239,15 @@ TORCH_LIBRARY_IMPL(c10d, XPU, m) {
 c10::intrusive_ptr<Work> allgather_into_tensor_coalesced_xpu_(
     at::TensorList outputs,
     at::TensorList inputs,
-    const c10::intrusive_ptr<ProcessGroup>& process_group) {
+    const c10::intrusive_ptr<ProcessGroup>& process_group,
+    bool asyncOp) {
   auto output_vec = outputs.vec();
   auto input_vec = inputs.vec();
+  AllgatherOptions opts = AllgatherOptions{};
+  opts.asyncOp = asyncOp;
 
   return process_group->getBackend(c10::DeviceType::XPU)
-      ->allgather_into_tensor_coalesced(output_vec, input_vec);
+      ->allgather_into_tensor_coalesced(output_vec, input_vec, opts);
 }
 
 TORCH_LIBRARY_IMPL(c10d, XPU, m) {
@@ -249,12 +257,16 @@ TORCH_LIBRARY_IMPL(c10d, XPU, m) {
 c10::intrusive_ptr<Work> allgather_coalesced_xpu_(
     const std::vector<std::vector<at::Tensor>>& output_lists,
     const at::TensorList& input_list,
-    const c10::intrusive_ptr<ProcessGroup>& process_group) {
+    const c10::intrusive_ptr<ProcessGroup>& process_group,
+    bool asyncOp) {
   auto input_list_vec = input_list.vec();
+  AllgatherOptions opts = AllgatherOptions{};
+  opts.asyncOp = asyncOp;
   return process_group->getBackend(c10::DeviceType::XPU)
       ->allgather_coalesced(
           const_cast<std::vector<std::vector<at::Tensor>>&>(output_lists),
-          input_list_vec);
+          input_list_vec,
+          opts);
 }
 
 TORCH_LIBRARY_IMPL(c10d, XPU, m) {
@@ -266,13 +278,14 @@ c10::intrusive_ptr<Work> gather_xpu_(
     const at::TensorList& input_tensors,
     const c10::intrusive_ptr<ProcessGroup>& process_group,
     int64_t root_rank,
+    bool asyncOp,
     int64_t timeout) {
   auto input_tensors_vec = input_tensors.vec();
   return process_group->getBackend(c10::DeviceType::XPU)
       ->gather(
           const_cast<std::vector<std::vector<at::Tensor>>&>(output_tensors),
           input_tensors_vec,
-          GatherOptions{root_rank, std::chrono::milliseconds(timeout)});
+          GatherOptions{root_rank, std::chrono::milliseconds(timeout), asyncOp});
 }
 
 TORCH_LIBRARY_IMPL(c10d, XPU, m) {
@@ -329,6 +342,7 @@ reduce_scatter_xpu_(
     const std::vector<std::vector<at::Tensor>>& input_tensors,
     const c10::intrusive_ptr<ProcessGroup>& process_group,
     const c10::intrusive_ptr<ReduceOp>& reduce_op,
+    bool asyncOp,
     int64_t timeout) {
   auto output_tensors_vec = output_tensors.vec();
   auto work =
@@ -337,7 +351,7 @@ reduce_scatter_xpu_(
               output_tensors_vec,
               const_cast<std::vector<std::vector<at::Tensor>>&>(input_tensors),
               ReduceScatterOptions{
-                  *reduce_op.get(), std::chrono::milliseconds(timeout)});
+                  *reduce_op.get(), std::chrono::milliseconds(timeout), asyncOp});
 
   return std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>(
       output_tensors_vec, work);
@@ -394,6 +408,7 @@ c10::intrusive_ptr<Work> reduce_scatter_tensor_coalesced_xpu_(
     at::TensorList inputs,
     const c10::intrusive_ptr<ProcessGroup>& process_group,
     const c10::intrusive_ptr<ReduceOp>& reduce_op,
+    bool asyncOp,
     int64_t timeout) {
   auto output_vec = outputs.vec();
   auto input_vec = inputs.vec();
@@ -402,7 +417,7 @@ c10::intrusive_ptr<Work> reduce_scatter_tensor_coalesced_xpu_(
           output_vec,
           input_vec,
           ReduceScatterOptions{
-              *reduce_op.get(), std::chrono::milliseconds(timeout)});
+              *reduce_op.get(), std::chrono::milliseconds(timeout), asyncOp});
 }
 
 TORCH_LIBRARY_IMPL(c10d, XPU, m) {
@@ -415,6 +430,7 @@ c10::intrusive_ptr<Work> alltoall_base_xpu_(
     const c10::intrusive_ptr<ProcessGroup>& process_group,
     std::vector<int64_t> output_split_sizes,
     std::vector<int64_t> input_split_sizes,
+    bool asyncOp,
     int64_t timeout) {
   return process_group->getBackend(c10::DeviceType::XPU)
       ->alltoall_base(
@@ -422,7 +438,7 @@ c10::intrusive_ptr<Work> alltoall_base_xpu_(
           input,
          output_split_sizes,
          input_split_sizes,
-          AllToAllOptions{std::chrono::milliseconds(timeout)});
+          AllToAllOptions{std::chrono::milliseconds(timeout), asyncOp});
 }
 
 TORCH_LIBRARY_IMPL(c10d, XPU, m) {
@@ -433,6 +449,7 @@ std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>> alltoall_xpu_
     const at::TensorList& output_tensors,
     const at::TensorList& input_tensors,
     const c10::intrusive_ptr<ProcessGroup>& process_group,
+    bool asyncOp,
     int64_t timeout) {
   auto output_tensors_vec = output_tensors.vec();
   auto input_tensors_vec = input_tensors.vec();
@@ -440,7 +457,7 @@ std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>> alltoall_xpu_
           ->alltoall(
               output_tensors_vec,
               input_tensors_vec,
-              AllToAllOptions{std::chrono::milliseconds(timeout)});
+              AllToAllOptions{std::chrono::milliseconds(timeout), asyncOp});
   return std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>(
       std::move(output_tensors_vec), work);
 }
@@ -494,9 +511,14 @@ c10::intrusive_ptr<Work> barrier_xpu(
     at::Tensor /* unused */,
     const c10::intrusive_ptr<ProcessGroup>& process_group,
     const std::vector<int64_t>& device_ids,
+    bool asyncOp,
     int64_t timeout) {
+  BarrierOptions opts = BarrierOptions{};
+  opts.device_ids = device_ids;
+  opts.timeout = std::chrono::milliseconds(timeout);
+  opts.asyncOp = asyncOp;
   return process_group->getBackend(c10::DeviceType::XPU)
-      ->barrier(BarrierOptions{device_ids, std::chrono::milliseconds(timeout)});
+      ->barrier(opts);
 }
 
 TORCH_LIBRARY_IMPL(c10d, XPU, m) {
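
Note on the pattern used above: wrappers whose options type is filled positionally (AllreduceOptions, GatherOptions, AllToAllOptions, ReduceScatterOptions) gain asyncOp as a trailing initializer, while wrappers such as allreduce_coalesced_xpu_, allgather_coalesced_xpu_, and barrier_xpu construct the options object and assign fields by name. The standalone sketch below illustrates the two initialization styles with a stand-in struct; the struct's fields and defaults are illustrative assumptions, not the real c10d definitions.

    #include <chrono>
    #include <cstdint>
    #include <iostream>
    #include <vector>

    // Stand-in for a c10d-style options struct (hypothetical defaults).
    struct BarrierOptions {
      std::vector<int64_t> device_ids;
      std::chrono::milliseconds timeout{30 * 60 * 1000};
      bool asyncOp = true;
    };

    int main() {
      // Positional aggregate initialization: every field up to asyncOp must be
      // supplied in declaration order, as the *Options{...} call sites do.
      BarrierOptions positional{{0, 1}, std::chrono::milliseconds(60000), false};

      // Field-by-field assignment, as barrier_xpu now does: set only the
      // fields the wrapper cares about and keep defaults for the rest.
      BarrierOptions assigned;
      assigned.device_ids = {0, 1};
      assigned.timeout = std::chrono::milliseconds(60000);
      assigned.asyncOp = false;

      std::cout << positional.asyncOp << " " << assigned.asyncOp << "\n";
      return 0;
    }

Field-by-field assignment avoids having to spell out every intermediate member when a struct grows a new trailing field, which is why the barrier and coalesced wrappers switch to it here.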