From 7fc87c007451c2f8a5954d22e8ec0f42074bdca7 Mon Sep 17 00:00:00 2001 From: AGUL Date: Mon, 12 Sep 2022 22:33:54 -0400 Subject: [PATCH 1/9] 1. Add XPU support to for torch 1.13. 2. Update oneCCL to 2021 gold release. --- .gitmodules | 2 +- CMakeLists.txt | 2 + README.md | 43 ++ cmake/Modules/FindoneCCL.cmake | 9 +- demo/demo.py | 6 +- oneccl_bindings_for_pytorch/__init__.py | 31 +- oneccl_bindings_for_pytorch/csrc/init.cpp | 9 + patches/Update_oneCCL.patch | 13 - requirements.txt | 2 + setup.py | 22 +- src/CMakeLists.txt | 2 +- src/ProcessGroupCCL.cpp | 29 +- src/ProcessGroupCCL.hpp | 21 +- src/ccl_comm_collector.h | 12 +- src/cpu/cpu_ccl.cpp | 116 +++++- src/dispatch_stub.cpp | 161 +++++++- src/dispatch_stub.h | 31 +- src/gpu/dpcpp_ccl.cpp | 477 ++++++++++++++++------ src/utils.cpp | 11 +- src/utils.h | 55 ++- tests/test_c10d_ccl.py | 130 ++++++ third_party/oneCCL | 2 +- tools/setup/env.py | 13 +- version.txt | 2 +- 24 files changed, 990 insertions(+), 211 deletions(-) delete mode 100644 patches/Update_oneCCL.patch create mode 100644 requirements.txt diff --git a/.gitmodules b/.gitmodules index b39ddc1..67e4d7b 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +1,3 @@ [submodule "third_party/oneCCL"] path = third_party/oneCCL - url = https://github.com/oneapi-src/oneCCL/ + url = https://github.com/oneapi-src/oneCCL.git diff --git a/CMakeLists.txt b/CMakeLists.txt index d1e1232..9a72c46 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -23,6 +23,8 @@ set(CMAKE_SKIP_INSTALL_ALL_DEPENDENCY true) option(USE_SYSTEM_ONECCL "Use oneCCL library in system" OFF) +option(BUILD_NO_ONECCL_PACKAGE "Build with oneCCL excluded" OFF) + # Find the Torch lib find_package(Torch REQUIRED) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}") diff --git a/README.md b/README.md index f685660..f1a7abb 100644 --- a/README.md +++ b/README.md @@ -10,6 +10,25 @@ This repository holds PyTorch bindings maintained by Intel for the Intel® oneAP `oneccl_bindings_for_pytorch` module implements PyTorch C10D ProcessGroup API and can be dynamically loaded as external ProcessGroup and only works on Linux platform now. +## Capability + +The table below shows which functions are available for use with CPU / Intel dGPU tensors. + +| | CPU | GPU | +| :--------------- | :---: | :---: | +| `send` | × | × | +| `recv` | × | × | +| `broadcast` | √ | √ | +| `all_reduce` | √ | √ | +| `reduce` | √ | √ | +| `all_gather` | √ | √ | +| `gather` | √ | √ | +| `scatter` | × | × | +| `reduce_scatter` | × | × | +| `all_to_all` | √ | √ | +| `barrier` | √ | √ | + + ## Pytorch API Align We recommend Anaconda as Python package management system. The following is the corresponding branches (tags) of `oneccl_bindings_for_pytorch` and supported Pytorch. @@ -36,6 +55,26 @@ The usage details can be found in the README of corresponding branch. The follow - PyTorch v1.13.0 +## Build Option List + +The following build options are supported in Intel® oneCCL Bindings for PyTorch*. 
+
+| Build Option                        | Default Value  | Description                                                                                              |
+| :---------------------------------: | :------------: | :------------------------------------------------------------------------------------------------------: |
+| COMPUTE_BACKEND                     |                | Set oneCCL `COMPUTE_BACKEND`; set to `dpcpp` to use the DPC++ Compiler and enable support for Intel XPU |
+| CCL_PACKAGE_NAME                    | oneccl-bind-pt | Set the wheel name                                                                                       |
+| ONECCL_BINDINGS_FOR_PYTORCH_BACKEND | cpu            | Set the backend                                                                                          |
+| CCL_SHA_VERSION                     | False          | Add the git HEAD SHA version to the wheel name                                                           |
+
+## Launch Option List
+
+The following launch options are supported in Intel® oneCCL Bindings for PyTorch*.
+
+| Launch Option                             | Default Value | Description                                                             |
+| :---------------------------------------: | :-----------: | :---------------------------------------------------------------------: |
+| ONECCL_BINDINGS_FOR_PYTORCH_ENV_VERBOSE   | 0             | Set the verbose level in ONECCL_BINDINGS_FOR_PYTORCH                    |
+| ONECCL_BINDINGS_FOR_PYTORCH_ENV_WAIT_GDB  | 0             | Set to 1 to force oneccl_bindings_for_pytorch to wait for GDB attaching |
+
 ## Installation
 
 ### Install from Source
@@ -51,7 +90,10 @@ The usage details can be found in the README of corresponding branch. The follow
 2. Install `oneccl_bindings_for_pytorch`
 
    ```bash
+   # for CPU backend only
    python setup.py install
+   # use DPC++ Compiler to enable support for Intel XPU
+   COMPUTE_BACKEND=dpcpp python setup.py install
   ```
 
 ### Install PreBuilt Wheel
@@ -69,6 +111,7 @@ Wheel files are avaiable for the following Python versions.
 ```bash
 python -m pip install oneccl_bind_pt==1.13 -f https://developer.intel.com/ipex-whl-stable-cpu
 ```
+
 ## Usage
 
 example.py
diff --git a/cmake/Modules/FindoneCCL.cmake b/cmake/Modules/FindoneCCL.cmake
index 2b4f884..7b153f4 100644
--- a/cmake/Modules/FindoneCCL.cmake
+++ b/cmake/Modules/FindoneCCL.cmake
@@ -23,7 +23,7 @@ IF (USE_SYSTEM_ONECCL)
         set(oneapi_root_hint $ENV{INTELONEAPIROOT})
     endif()
 
-    IF(COMPUTE_BACKEND STREQUAL "dpcpp_level_zero")
+    IF(COMPUTE_BACKEND STREQUAL "dpcpp")
         SET(CCL_CONFIGURATION "cpu_gpu_dpcpp")
     ELSE()
         SET(CCL_CONFIGURATION "cpu_icc")
@@ -34,7 +34,12 @@ IF (USE_SYSTEM_ONECCL)
 ELSE()
     SET(ONECCL_ROOT "${PROJECT_SOURCE_DIR}/third_party/oneCCL")
 
-    ADD_SUBDIRECTORY(${ONECCL_ROOT})
+    IF(BUILD_NO_ONECCL_PACKAGE)
+        ADD_SUBDIRECTORY(${ONECCL_ROOT} oneCCL EXCLUDE_FROM_ALL)
+    ELSE()
+        ADD_SUBDIRECTORY(${ONECCL_ROOT})
+    ENDIF()
+
     IF(NOT TARGET ccl)
         MESSAGE(FATAL_ERROR "Failed to find oneCCL target")
     ENDIF()
diff --git a/demo/demo.py b/demo/demo.py
index a93cdf6..05deda8 100644
--- a/demo/demo.py
+++ b/demo/demo.py
@@ -39,7 +39,7 @@ def forward(self, input):
 device = 'cpu' #"xpu:{}".format(dist.get_rank())
 model = Model().to(device)
 if dist.get_world_size() > 1:
-    model = DDP(model, device_ids=[device] if device is not 'cpu' else None)
+    model = DDP(model, device_ids=[device] if (device != 'cpu') else None)
 
 optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
 loss_fn = nn.MSELoss().to(device)
@@ -55,7 +55,9 @@ def forward(self, input):
     L = loss_fn(res, labels)
     # backward
     print("Runing backward: {} on device {}".format(i, device))
-    L.backward()
+    with torch.autograd.profiler_legacy.profile(enabled=True, use_xpu=True) as prof:
+        L.backward()
+    print(prof)
     # update
     print("Runing optim: {} on device {}".format(i, device))
     optimizer.step()
diff --git a/oneccl_bindings_for_pytorch/__init__.py b/oneccl_bindings_for_pytorch/__init__.py
index c7e03de..4928bdf 100644
--- a/oneccl_bindings_for_pytorch/__init__.py
+++ b/oneccl_bindings_for_pytorch/__init__.py
@@ -3,24 +3,35 @@
 import warnings
 import torch
+
 cwd = 
os.path.dirname(os.path.abspath(__file__)) -os.environ['CCL_ROOT'] = cwd -FI_PROVIDER_PATH = os.path.join(cwd, "lib/prov") -os.environ['FI_PROVIDER_PATH'] = FI_PROVIDER_PATH if not os.path.exists(os.path.join(cwd, "version.py")): raise RuntimeError("oneccl_bindings_for_pytorch is not installed!") + +def set_env_default(env, key, value): + new_value = env.get(key, value) + env[key] = new_value + + +if os.environ.get("CCL_ROOT") is None: + # set the default oneCCL and MPI library path + set_env_default(os.environ, 'CCL_ROOT', cwd) + + FI_PROVIDER_PATH = os.path.join(cwd, "lib/prov") + set_env_default(os.environ, 'FI_PROVIDER_PATH', FI_PROVIDER_PATH) + + from .version import __version__, git_version from . import _C as ccl_lib if hasattr(torch, 'xpu'): - if torch.xpu.is_available(): - try: - # load the CCL/XPU library - import ctypes - my_c_library = ctypes.cdll.LoadLibrary(os.path.join(cwd, "lib/liboneccl_bindings_for_pytorch_xpu.so")) - except OSError: - print("Warning: Cannot load xpu CCL. CCL doesn't work for XPU device") + try: + # load the CCL/XPU library + import ctypes + my_c_library = ctypes.cdll.LoadLibrary(os.path.join(cwd, "lib/liboneccl_bindings_for_pytorch_xpu.so")) + except OSError: + print("Warning: Cannot load xpu CCL. CCL doesn't work for XPU device") __all__ = [] __all__ += [name for name in dir(ccl_lib) diff --git a/oneccl_bindings_for_pytorch/csrc/init.cpp b/oneccl_bindings_for_pytorch/csrc/init.cpp index f89a44a..1b35b5f 100644 --- a/oneccl_bindings_for_pytorch/csrc/init.cpp +++ b/oneccl_bindings_for_pytorch/csrc/init.cpp @@ -42,10 +42,19 @@ #include #include +#include +#if TORCH_VERSION_MINOR >= 13 #include #include #include #include +#else +#include +#include +#include +#include +#endif + #include namespace py = pybind11; diff --git a/patches/Update_oneCCL.patch b/patches/Update_oneCCL.patch deleted file mode 100644 index 52bdf0f..0000000 --- a/patches/Update_oneCCL.patch +++ /dev/null @@ -1,13 +0,0 @@ -diff --git a/third_party/oneCCL/src/CMakeLists.txt b/third_party/oneCCL/src/CMakeLists.txt -index 7b4cff9b..e3587da5 100644 ---- a/third_party/oneCCL/src/CMakeLists.txt -+++ b/third_party/oneCCL/src/CMakeLists.txt -@@ -292,6 +292,8 @@ endif() - # shared library - add_library(ccl SHARED $) - target_include_directories(ccl PUBLIC ${SRC_INCLUDE_DIRS}) -+set(ONEAPI_IMPI_RPATH "'$ORIGIN'") -+set_target_properties(ccl PROPERTIES LINK_FLAGS "-Wl,-rpath,${ONEAPI_IMPI_RPATH}") - target_link_libraries(ccl PRIVATE ${SRC_LINK_LIBS}) - - set_target_properties(ccl PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${CCL_BUILD_DIR}) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..083abc6 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,2 @@ +torch>=1.10.0 +setuptools~=51.0.0 \ No newline at end of file diff --git a/setup.py b/setup.py index 200b5aa..a1853d7 100644 --- a/setup.py +++ b/setup.py @@ -49,8 +49,8 @@ def create_version(): if sha != 'Unknown': version += '+' + sha[:7] - if os.environ.get("COMPUTE_BACKEND") == "dpcpp_level_zero": - backend = "xpu" + if os.environ.get("COMPUTE_BACKEND") == "dpcpp": + backend = "gpu" else: backend = os.environ.get("ONECCL_BINDINGS_FOR_PYTORCH_BACKEND", "cpu") @@ -78,12 +78,6 @@ def run(self): """ cmake_extensions = [ext for ext in self.extensions if isinstance(ext, CMakeExtension)] for ext in cmake_extensions: - try: - # temp patch the oneCCL code - check_call(["git", "apply", "./patches/Update_oneCCL.patch"], cwd=CWD) - except: - # ignore patch fail - pass self.build_cmake(ext) self.extensions = [ext for ext in 
self.extensions if not isinstance(ext, CMakeExtension)] @@ -123,7 +117,7 @@ def build_cmake(self, extension: CMakeExtension): runtime = 'gcc' if 'COMPUTE_BACKEND' in os.environ: - if os.environ['COMPUTE_BACKEND'] == 'dpcpp_level_zero': + if os.environ['COMPUTE_BACKEND'] == 'dpcpp': runtime = 'dpcpp' build_options['COMPUTE_BACKEND'] = os.environ['COMPUTE_BACKEND'] import intel_extension_for_pytorch @@ -138,7 +132,7 @@ def build_cmake(self, extension: CMakeExtension): build_args = ['-j', str(os.cpu_count())] check_call(['make', 'oneccl_bindings_for_pytorch'] + build_args, cwd=str(build_dir)) if 'COMPUTE_BACKEND' in os.environ: - if os.environ['COMPUTE_BACKEND'] == 'dpcpp_level_zero': + if os.environ['COMPUTE_BACKEND'] == 'dpcpp': check_call(['make', 'oneccl_bindings_for_pytorch_xpu'] + build_args, cwd=str(build_dir)) check_call(['make', 'install'], cwd=str(build_dir)) @@ -148,14 +142,6 @@ def run(self): import glob import re - if os.path.isfile(os.path.join(CWD, "third_party/oneCCL", "README.md")): - try: - check_call(["git", "reset", "--hard"], cwd=os.path.join(CWD, "third_party/oneCCL")) - except Exception as e: - print("=" * 64 + "\nWARNNING!\n" + "=" * 64) - print(e) - print("=" * 64) - with open('.gitignore', 'r') as f: ignores = f.read() pat = re.compile(r'^#( BEGIN NOT-CLEAN-FILES )?') diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 4f02c2d..cd56bac 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -8,7 +8,7 @@ target_compile_options(oneccl_bindings_for_pytorch PUBLIC -Wall -Wno-sign-compare -Wno-unused-function) -if(COMPUTE_BACKEND STREQUAL "dpcpp_level_zero") +if(COMPUTE_BACKEND STREQUAL "dpcpp") add_subdirectory(./gpu) endif() diff --git a/src/ProcessGroupCCL.cpp b/src/ProcessGroupCCL.cpp index c4ac31f..991d0b2 100644 --- a/src/ProcessGroupCCL.cpp +++ b/src/ProcessGroupCCL.cpp @@ -92,12 +92,13 @@ ProcessGroupCCL::AsyncWorkCCL::AsyncWorkCCL(std::vector> // Profiler: Pass nullptr as profilingTitle to parent constructor to // replace default profiler implementation with async version that reports // correct timestamps for work that is asynchronously executed. - : C10D_Work(rank, opType, profilingTitle, inputTensors), + : C10D_Work(rank, opType, nullptr, inputTensors), outputTensors_(std::move(outputTensors)), future_(createFutureAsOutput(outputTensors)) { -// if (profilingTitle != nullptr) { + if (profilingTitle != nullptr) { // recordAsyncWorkProfilingInfo(profilingTitle, inputTensors); -// } + // TODO: for cpu async profiling repot. 
+ } } c10::intrusive_ptr ProcessGroupCCL::AsyncWorkCCL::getFuture() { @@ -243,7 +244,13 @@ c10::intrusive_ptr ProcessGroupCCL::_allgather_base( at::Tensor& inputTensor, const AllgatherOptions& opts) { - TORCH_CHECK(false, "ProcessGroupCCL does not support _allgather_base"); + std::vector tensor_param; + format_tensors_param(tensor_param, inputTensor); + format_tensors_param(tensor_param, outputTensor); + RECORD_FUNCTION("oneccl_bindings_for_pytorch::_allgather_base", tensor_param); + + auto work = DispatchStub::_allgather_base(outputTensor, inputTensor, opts, *this); + return work; } c10::intrusive_ptr ProcessGroupCCL::allgather_coalesced( @@ -290,6 +297,20 @@ c10::intrusive_ptr ProcessGroupCCL::reduce_scatter( TORCH_CHECK(false, "ProcessGroupCCL does not support reduce_scatter"); } + +c10::intrusive_ptr ProcessGroupCCL::_reduce_scatter_base( + at::Tensor& outputTensor, + at::Tensor& inputTensor, + const ReduceScatterOptions& opts) { + std::vector tensor_param; + format_tensors_param(tensor_param, inputTensor); + format_tensors_param(tensor_param, outputTensor); + RECORD_FUNCTION("oneccl_bindings_for_pytorch::_reduce_scatter_base", tensor_param); + + auto work = DispatchStub::_reduce_scatter_base(outputTensor, inputTensor, opts, *this); + return work; +} + c10::intrusive_ptr ProcessGroupCCL::alltoall_base( at::Tensor& outputTensor, at::Tensor& inputTensor, diff --git a/src/ProcessGroupCCL.hpp b/src/ProcessGroupCCL.hpp index 86c9263..ef534d0 100644 --- a/src/ProcessGroupCCL.hpp +++ b/src/ProcessGroupCCL.hpp @@ -38,10 +38,18 @@ #include #include +#if TORCH_VERSION_MINOR >= 13 #include #include #include #include +#else +#include +#include +#include +#include +#endif + namespace oneccl_bindings_for_pytorch { struct CCLCommCollector; @@ -94,15 +102,15 @@ class ProcessGroupCCL : public ProcessGroup std::vector result() override; - void finishAsyncWorkCCL(); + virtual void finishAsyncWorkCCL(); void finishAsyncWorkCCLError(std::exception_ptr eptr); - protected: - friend class ProcessGroupCCL; - public: std::string debugName; + + protected: + friend class ProcessGroupCCL; const std::vector> outputTensors_; // The future returned by getFuture. 
c10::intrusive_ptr future_; @@ -167,6 +175,11 @@ class ProcessGroupCCL : public ProcessGroup std::vector>& inputTensors, const ReduceScatterOptions& opts = ReduceScatterOptions()) override; + c10::intrusive_ptr _reduce_scatter_base( + at::Tensor& outputBuffer, + at::Tensor& inputBuffer, + const ReduceScatterOptions& opts = ReduceScatterOptions()) override; + c10::intrusive_ptr alltoall_base( at::Tensor& outputTensor, at::Tensor& inputTensor, diff --git a/src/ccl_comm_collector.h b/src/ccl_comm_collector.h index aa1d98c..7ec667b 100644 --- a/src/ccl_comm_collector.h +++ b/src/ccl_comm_collector.h @@ -40,11 +40,13 @@ namespace oneccl_bindings_for_pytorch { class Comms { public: + // for cpu case explicit Comms(ccl::vector_class &comms) : comms(std::move(comms)), streams{} {} - explicit Comms(ccl::vector_class &comms, std::vector &streams) : - comms(std::move(comms)), streams(std::move(streams)) {} + // for comms with streams + explicit Comms(ccl::vector_class &comms, ccl::vector_class &streams, std::vector &torch_streams) : + comms(std::move(comms)), streams(std::move(streams)), torch_streams(std::move(torch_streams)) {} ~Comms() noexcept(false) {} @@ -56,12 +58,14 @@ class Comms { Comms &operator=(const Comms &) = delete; // Move constructable - Comms(Comms &&other) : comms(std::move(other.comms)), streams(std::move(other.streams)) {} + Comms(Comms &&other) : comms(std::move(other.comms)), streams(std::move(other.streams)), + torch_streams(std::move(other.torch_streams)) {} // Move assignable Comms &operator=(Comms &&other) { std::swap(comms, other.comms); std::swap(streams, other.streams); + std::swap(torch_streams, other.torch_streams); return *this; } @@ -70,6 +74,8 @@ class Comms { ccl::vector_class comms; // The streams used by CCL ccl::vector_class streams; + // one to one mapping the torch streams to the ccl::stream. 
+ std::vector torch_streams; }; struct CCLCommCollector { diff --git a/src/cpu/cpu_ccl.cpp b/src/cpu/cpu_ccl.cpp index 93d1cfd..11cb343 100644 --- a/src/cpu/cpu_ccl.cpp +++ b/src/cpu/cpu_ccl.cpp @@ -134,6 +134,11 @@ class VanillaCPU final: public DispatchStub { const ReduceOptions& opts, ProcessGroupCCL& pg) override; + c10::intrusive_ptr _reduce_scatter_base_(at::Tensor& outputTensor, + at::Tensor& inputTensor, + const ReduceScatterOptions& opts, + ProcessGroupCCL& pg_ccl) override; + c10::intrusive_ptr broadcast_(std::vector& tensors, const BroadcastOptions& opts, ProcessGroupCCL& pg) override; @@ -143,6 +148,11 @@ class VanillaCPU final: public DispatchStub { const AllgatherOptions& opts, ProcessGroupCCL& pg) override; + c10::intrusive_ptr _allgather_base_(at::Tensor& outputTensor, + at::Tensor& inputTensor, + const AllgatherOptions& opts, + ProcessGroupCCL& pg_ccl) override; + c10::intrusive_ptr gather_(std::vector>& outputTensors, std::vector& inputTensors, const GatherOptions& opts, @@ -337,6 +347,51 @@ c10::intrusive_ptr VanillaCPU::reduce_(std::vecto return work; } +c10::intrusive_ptr VanillaCPU::_reduce_scatter_base_(at::Tensor& outputTensor, + at::Tensor& inputTensor, + const ReduceScatterOptions& opts, + ProcessGroupCCL& pg_ccl) { + const int world_size = pg_ccl.getSize(); + if (inputTensor.numel() != outputTensor.numel() * world_size) { + TORCH_CHECK( + false, + "input tensor must be the same size as output size times world size"); + } + + // just a wrapper to fit the collective interface + auto inputs = std::vector {inputTensor}; + auto outputs = std::vector {outputTensor}; + + c10::intrusive_ptr work; + work = collective( + pg_ccl, + inputs, + outputs, + [=](at::Tensor input, + at::Tensor output, + ccl::reduce_attr attr, + ccl::communicator& comm) { + + ccl::event ret_evt; + call_with_lock(c10d::ProcessGroupCCL::globalMutex, [&]() { + CCL_CHECK(ret_evt = ccl::reduce_scatter(input.data_ptr(), + output.data_ptr(), + (size_t) output.numel(), + cclDatatypes.at(input.scalar_type()), + cclOps.at(opts.reduceOp), + comm)); + }); + return ret_evt; + + }, + c10d::OpType::_REDUCE_SCATTER_BASE, + "oneccl_bindings_for_pytorch::cpu_work::_reduce_scatter_base"); + + work->debugName = std::string("cpu::_reduce_scatter_base"); + enqueue(work); + return work; +} + c10::intrusive_ptr VanillaCPU::broadcast_(std::vector& tensors, const BroadcastOptions &opts, ProcessGroupCCL& pg) { @@ -438,6 +493,51 @@ c10::intrusive_ptr VanillaCPU::allgather_(std::ve return work; } +c10::intrusive_ptr VanillaCPU::_allgather_base_(at::Tensor& outputTensor, + at::Tensor& inputTensor, + const AllgatherOptions& opts, + ProcessGroupCCL& pg_ccl) { + const int world_size = pg_ccl.getSize(); + if (inputTensor.numel() * world_size != outputTensor.numel()) { + TORCH_CHECK(false, "output tensor size must be equal to world_size times input tensor size"); + } + + // just a wrapper to fit the collective interface + auto inputs = std::vector {inputTensor}; + auto outputs = std::vector {outputTensor}; + + c10::intrusive_ptr work; + work = collective( + pg_ccl, + inputs, + outputs, + [=](at::Tensor input, + at::Tensor output, + ccl::allgatherv_attr attr, + ccl::communicator& comm) { + RECORD_FUNCTION("oneccl_bindings_for_pytorch::cpu::_allgather_base", std::vector({input})); + + std::vector recvCounts(world_size, input.numel()); + + ccl::event ret_evt; + call_with_lock(c10d::ProcessGroupCCL::globalMutex, [&]() { + CCL_CHECK(ret_evt = ccl::allgatherv(input.data_ptr(), + (size_t) input.numel(), + output.data_ptr(), + recvCounts, + 
cclDatatypes.at(input.scalar_type()), + comm, + attr)); + }); + return ret_evt; + }, + c10d::OpType::_ALLGATHER_BASE); + + work->debugName = std::string("cpu::_allgather_base"); + enqueue(work); + return work; +} + c10::intrusive_ptr VanillaCPU::gather_(std::vector>& outputTensors, std::vector& inputTensors, const GatherOptions& opts, @@ -713,17 +813,29 @@ c10::intrusive_ptr VanillaCPU::barrier_(const Bar ProcessGroupCCL& pg) { c10::intrusive_ptr work = c10::make_intrusive(); + + if (pg.ccl_member_->ccl_comms.size() == 0) { + std::vector cpu_devices{at::Device("cpu")}; + const auto key = get_key_from_devs(cpu_devices); + get_ccl_comms(pg, key, cpu_devices); + } + auto& comms_map = pg.ccl_member_->ccl_comms; for(auto iter = comms_map.begin(); iter != comms_map.end(); iter++){ for(size_t i =0 ; i < iter->second->comms.size(); i++){ work->getEvents().emplace_back( call_with_lock(c10d::ProcessGroupCCL::globalMutex, [&](){ - CCL_CHECK(return ccl::barrier(iter->second->comms[i]);); + if (i < iter->second->streams.size()) { + CCL_CHECK(return ccl::barrier(iter->second->comms[i], + iter->second->streams[i]);); + } else { + CCL_CHECK(return ccl::barrier(iter->second->comms[i]);); + } }) ); } } - return work; + return work; } diff --git a/src/dispatch_stub.cpp b/src/dispatch_stub.cpp index 7641b22..f472ae2 100644 --- a/src/dispatch_stub.cpp +++ b/src/dispatch_stub.cpp @@ -29,6 +29,7 @@ * POSSIBILITY OF SUCH DAMAGE. */ +#include #include "env.h" #include "dispatch_stub.h" @@ -57,15 +58,19 @@ static void format_tensors_size(std::ostream& os, const std::vector& vec) { os << "]"; } -static void format_pg_rank(std::ostream& os, const ProcessGroupCCL& pg_ccl) { - os << "[" <allreduce_(tensors, opts, pg_ccl); + auto currentTimepoint = std::chrono::steady_clock::now(); + auto timeElapsed = + std::chrono::duration_cast( + currentTimepoint - workStartTime_); + format_time_elapsed(os, timeElapsed); + std::cout << os.str() << std::endl; return work; } @@ -90,12 +102,43 @@ class DebugCCLStub final: public DispatchStub { ProcessGroupCCL& pg_ccl) override { std::stringstream os; os << "oneccl_bindings_for_pytorch::" << dev_type << "::reduce: "; - format_pg_rank(os, pg_ccl); + format_pg_rank_with_number(os, pg_ccl, ccl_primitive_number++); os << " "; format_tensors_size(os, tensors); std::cout << os.str() << std::endl; + auto workStartTime_ = std::chrono::steady_clock::now(); auto work = hdlr->reduce_(tensors, opts, pg_ccl); + auto currentTimepoint = std::chrono::steady_clock::now(); + auto timeElapsed = + std::chrono::duration_cast( + currentTimepoint - workStartTime_); + format_time_elapsed(os, timeElapsed); + std::cout << os.str() << std::endl; + return work; + } + + c10::intrusive_ptr _reduce_scatter_base_(at::Tensor& outputTensor, + at::Tensor& inputTensor, + const ReduceScatterOptions& opts, + ProcessGroupCCL& pg_ccl) override { + std::stringstream os; + os << "oneccl_bindings_for_pytorch::" << dev_type << "::_reduce_scatter_base: "; + format_pg_rank_with_number(os, pg_ccl, ccl_primitive_number++); + os << " input "; + format_tensors_size(os, inputTensor); + os << " output "; + format_tensors_size(os, outputTensor); + std::cout << os.str() << std::endl; + + auto workStartTime_ = std::chrono::steady_clock::now(); + auto work = hdlr->_reduce_scatter_base_(outputTensor, inputTensor, opts, pg_ccl); + auto currentTimepoint = std::chrono::steady_clock::now(); + auto timeElapsed = + std::chrono::duration_cast( + currentTimepoint - workStartTime_); + format_time_elapsed(os, timeElapsed); + std::cout << os.str() << 
std::endl; return work; } @@ -105,14 +148,21 @@ class DebugCCLStub final: public DispatchStub { ProcessGroupCCL& pg_ccl) override { std::stringstream os; os << "oneccl_bindings_for_pytorch::" << dev_type << "::allgather: "; - format_pg_rank(os, pg_ccl); + format_pg_rank_with_number(os, pg_ccl, ccl_primitive_number++); os << " input "; format_tensors_size(os, inputTensors); os << " output "; format_tensors_size(os, outputTensors); std::cout << os.str() << std::endl; + auto workStartTime_ = std::chrono::steady_clock::now(); auto work = hdlr->allgather_(outputTensors, inputTensors, opts, pg_ccl); + auto currentTimepoint = std::chrono::steady_clock::now(); + auto timeElapsed = + std::chrono::duration_cast( + currentTimepoint - workStartTime_); + format_time_elapsed(os, timeElapsed); + std::cout << os.str() << std::endl; return work; } @@ -122,14 +172,45 @@ class DebugCCLStub final: public DispatchStub { ProcessGroupCCL& pg_ccl) override { std::stringstream os; os << "oneccl_bindings_for_pytorch::" << dev_type << "::gather: "; - format_pg_rank(os, pg_ccl); + format_pg_rank_with_number(os, pg_ccl, ccl_primitive_number++); os << " input "; format_tensors_size(os, inputTensors); os << " output "; format_tensors_size(os, outputTensors); std::cout << os.str() << std::endl; + auto workStartTime_ = std::chrono::steady_clock::now(); auto work = hdlr->gather_(outputTensors, inputTensors, opts, pg_ccl); + auto currentTimepoint = std::chrono::steady_clock::now(); + auto timeElapsed = + std::chrono::duration_cast( + currentTimepoint - workStartTime_); + format_time_elapsed(os, timeElapsed); + std::cout << os.str() << std::endl; + return work; + } + + c10::intrusive_ptr _allgather_base_(at::Tensor& outputTensor, + at::Tensor& inputTensor, + const AllgatherOptions& opts, + ProcessGroupCCL& pg_ccl) override { + std::stringstream os; + os << "oneccl_bindings_for_pytorch::" << dev_type << "::_allgather_base: "; + format_pg_rank_with_number(os, pg_ccl, ccl_primitive_number++); + os << " input "; + format_tensors_size(os, inputTensor); + os << " output "; + format_tensors_size(os, outputTensor); + std::cout << os.str() << std::endl; + + auto workStartTime_ = std::chrono::steady_clock::now(); + auto work = hdlr->_allgather_base_(outputTensor, inputTensor, opts, pg_ccl); + auto currentTimepoint = std::chrono::steady_clock::now(); + auto timeElapsed = + std::chrono::duration_cast( + currentTimepoint - workStartTime_); + format_time_elapsed(os, timeElapsed); + std::cout << os.str() << std::endl; return work; } @@ -139,14 +220,21 @@ class DebugCCLStub final: public DispatchStub { ProcessGroupCCL& pg_ccl) override { std::stringstream os; os << "oneccl_bindings_for_pytorch::" << dev_type << "::scatter: "; - format_pg_rank(os, pg_ccl); + format_pg_rank_with_number(os, pg_ccl, ccl_primitive_number++); os << " input "; format_tensors_size(os, inputTensors); os << " output "; format_tensors_size(os, outputTensors); std::cout << os.str() << std::endl; + auto workStartTime_ = std::chrono::steady_clock::now(); auto work = hdlr->scatter_(outputTensors, inputTensors, opts, pg_ccl); + auto currentTimepoint = std::chrono::steady_clock::now(); + auto timeElapsed = + std::chrono::duration_cast( + currentTimepoint - workStartTime_); + format_time_elapsed(os, timeElapsed); + std::cout << os.str() << std::endl; return work; } @@ -155,12 +243,19 @@ class DebugCCLStub final: public DispatchStub { ProcessGroupCCL& pg_ccl) override { std::stringstream os; os << "oneccl_bindings_for_pytorch::" << dev_type << "::broadcast: "; - 
format_pg_rank(os, pg_ccl); + format_pg_rank_with_number(os, pg_ccl, ccl_primitive_number++); os << " "; format_tensors_size(os, tensors); std::cout << os.str() << std::endl; + auto workStartTime_ = std::chrono::steady_clock::now(); auto work = hdlr->broadcast_(tensors, opts, pg_ccl); + auto currentTimepoint = std::chrono::steady_clock::now(); + auto timeElapsed = + std::chrono::duration_cast( + currentTimepoint - workStartTime_); + format_time_elapsed(os, timeElapsed); + std::cout << os.str() << std::endl; return work; } @@ -172,7 +267,7 @@ class DebugCCLStub final: public DispatchStub { ProcessGroupCCL& pg_ccl) override { std::stringstream os; os << "oneccl_bindings_for_pytorch::" << dev_type << "::alltoall_base: "; - format_pg_rank(os, pg_ccl); + format_pg_rank_with_number(os, pg_ccl, ccl_primitive_number++); os << " input "; format_tensors_size(os, inputTensor); os << " output "; @@ -181,7 +276,14 @@ class DebugCCLStub final: public DispatchStub { os << " outputSplitSizes [" << outputSplitSizes << "]"; std::cout << os.str() << std::endl; + auto workStartTime_ = std::chrono::steady_clock::now(); auto work = hdlr->alltoall_base_(outputTensor, inputTensor, outputSplitSizes, inputSplitSizes, opts, pg_ccl); + auto currentTimepoint = std::chrono::steady_clock::now(); + auto timeElapsed = + std::chrono::duration_cast( + currentTimepoint - workStartTime_); + format_time_elapsed(os, timeElapsed); + std::cout << os.str() << std::endl; return work; } @@ -191,14 +293,21 @@ class DebugCCLStub final: public DispatchStub { ProcessGroupCCL& pg_ccl) override { std::stringstream os; os << "oneccl_bindings_for_pytorch::" << dev_type << "::alltoall: "; - format_pg_rank(os, pg_ccl); + format_pg_rank_with_number(os, pg_ccl, ccl_primitive_number++); os << " inputs "; format_tensors_size(os, inputTensors); os << " outputs "; format_tensors_size(os, outputTensors); std::cout << os.str() << std::endl; + auto workStartTime_ = std::chrono::steady_clock::now(); auto work = hdlr->alltoall_(outputTensors, inputTensors, opts, pg_ccl); + auto currentTimepoint = std::chrono::steady_clock::now(); + auto timeElapsed = + std::chrono::duration_cast( + currentTimepoint - workStartTime_); + format_time_elapsed(os, timeElapsed); + std::cout << os.str() << std::endl; return work; } @@ -206,15 +315,23 @@ class DebugCCLStub final: public DispatchStub { ProcessGroupCCL& pg_ccl) override { std::stringstream os; os << "oneccl_bindings_for_pytorch::" << dev_type << "::barrier: "; - format_pg_rank(os, pg_ccl); + format_pg_rank_with_number(os, pg_ccl, ccl_primitive_number++); std::cout << os.str() << std::endl; + auto workStartTime_ = std::chrono::steady_clock::now(); auto work = hdlr->barrier_(opts, pg_ccl); + auto currentTimepoint = std::chrono::steady_clock::now(); + auto timeElapsed = + std::chrono::duration_cast( + currentTimepoint - workStartTime_); + format_time_elapsed(os, timeElapsed); + std::cout << os.str() << std::endl; return work; } private: c10::DeviceType dev_type; DispatchStub* hdlr; + int64_t ccl_primitive_number; }; @@ -264,6 +381,16 @@ c10::intrusive_ptr DispatchStub::reduce(std::vect return get_ccl_stub(dev_type)->reduce_(tensors, opts, pg_ccl); } + +c10::intrusive_ptr DispatchStub::_reduce_scatter_base(at::Tensor& outputTensor, + at::Tensor& inputTensor, + const ReduceScatterOptions& opts, + ProcessGroupCCL& pg_ccl) { + checkSameType(outputTensor, {outputTensor, inputTensor}); + c10::DeviceType dev_type = outputTensor.device().type(); + return get_ccl_stub(dev_type)->_reduce_scatter_base_(outputTensor, 
inputTensor, opts, pg_ccl); +} + c10::intrusive_ptr DispatchStub::broadcast(std::vector& tensors, const BroadcastOptions& opts, ProcessGroupCCL& pg_ccl) { @@ -282,6 +409,16 @@ c10::intrusive_ptr DispatchStub::allgather(std::v return get_ccl_stub(dev_type)->allgather_(outputTensors, inputTensors, opts, pg_ccl); } +c10::intrusive_ptr DispatchStub::_allgather_base( + at::Tensor& outputTensor, + at::Tensor& inputTensor, + const AllgatherOptions& opts, + ProcessGroupCCL& pg_ccl) { + checkSameType(inputTensor, std::vector{outputTensor}); + c10::DeviceType dev_type = inputTensor.device().type(); + return get_ccl_stub(dev_type)->_allgather_base_(outputTensor, inputTensor, opts, pg_ccl); +} + c10::intrusive_ptr DispatchStub::gather(std::vector>& outputTensors, std::vector& inputTensors, const GatherOptions& opts, diff --git a/src/dispatch_stub.h b/src/dispatch_stub.h index 75293a0..d01cff6 100644 --- a/src/dispatch_stub.h +++ b/src/dispatch_stub.h @@ -55,6 +55,11 @@ class DispatchStub { const ReduceOptions& opts, ProcessGroupCCL& pg_ccl); + static c10::intrusive_ptr _reduce_scatter_base(at::Tensor& outputTensor, + at::Tensor& inputTensor, + const ReduceScatterOptions& opts, + ProcessGroupCCL& pg_ccl); + static c10::intrusive_ptr broadcast(std::vector& tensors, const BroadcastOptions& opts, ProcessGroupCCL& pg_ccl); @@ -63,7 +68,13 @@ class DispatchStub { std::vector& inputTensors, const AllgatherOptions& opts, ProcessGroupCCL& pg_ccl); - + + static c10::intrusive_ptr _allgather_base( + at::Tensor& outputBuffer, + at::Tensor& inputBuffer, + const AllgatherOptions& opts, + ProcessGroupCCL& pg_ccl); + static c10::intrusive_ptr gather(std::vector>& outputTensors, std::vector& inputTensors, const GatherOptions& opts, @@ -106,6 +117,15 @@ class DispatchStub { return c10::intrusive_ptr(); } + virtual c10::intrusive_ptr _reduce_scatter_base_(at::Tensor& outputTensor, + at::Tensor& inputTensor, + const ReduceScatterOptions& opts, + ProcessGroupCCL& pg_ccl) { + fail(outputTensor.device().type(), "_reduce_scatter_base"); + return c10::intrusive_ptr(); + } + + virtual c10::intrusive_ptr allgather_(std::vector>& outputTensors, std::vector& inputTensors, const AllgatherOptions& opts, @@ -115,6 +135,15 @@ class DispatchStub { return c10::intrusive_ptr(); } + virtual c10::intrusive_ptr _allgather_base_(at::Tensor& outputTensor, + at::Tensor& inputTensor, + const AllgatherOptions& opts, + ProcessGroupCCL& pg_ccl) { + + fail(inputTensor.device().type(), "_allgather_base"); + return c10::intrusive_ptr(); + } + virtual c10::intrusive_ptr gather_(std::vector>& outputTensors, std::vector& inputTensors, const GatherOptions& opts, diff --git a/src/gpu/dpcpp_ccl.cpp b/src/gpu/dpcpp_ccl.cpp index 55f1c1f..c87c27f 100644 --- a/src/gpu/dpcpp_ccl.cpp +++ b/src/gpu/dpcpp_ccl.cpp @@ -34,10 +34,63 @@ #include #include + + +#define CCL_KERNEL_SUBMIT(cmd, q) \ +({bool profile_barrier = (xpu::is_profiler_enabled()); \ + sycl::event start_evt; \ + if (profile_barrier) { \ + start_evt = q.ext_oneapi_submit_barrier(); \ + } \ + CCL_CHECK(cmd); \ + \ + sycl::event end_evt; \ + if (profile_barrier) { \ + end_evt = q.ext_oneapi_submit_barrier(); \ + xpu::profiler_record("oneccl", start_evt, end_evt); \ + } \ + }) + + namespace oneccl_bindings_for_pytorch { namespace { +// [Sync Streams] Helper that lets the input ccl::stream to wait for the current +// stream. oneCCL communications run on ccl::stream, but input tensors are +// allocated on different streams (i.e., current streams). 
Communications on +// ccl::stream cannot start before pending input tensor ops on current streams +// finish. Otherwise, ops on two streams might read/write same tensors +// concurrently. +// +// The synchronization above alone is not enough. We also need to make sure +// input tensors are not freed before their usages on ccl::stream finish. This +// can be achieved by calling aten::record_stream, +// which remembers the usage stream (ccl::stream), creates an event on the usage +// stream when GC attempts to free the input tensor, and delays GC until that +// event is done. +void sync_streams( + const std::vector& devices, + const std::vector& ccl_torch_streams) { + for (const auto i : c10::irange(devices.size())) { + c10::impl::VirtualGuardImpl impl(devices[i].type()); + c10::Stream stream = impl.getStream(devices[i]); + c10::Event evt(at::kXPU); + evt.record(stream); + c10::Stream ccl_torch_stream = ccl_torch_streams[i]; + evt.block(ccl_torch_stream); + } +} + +void record_tensor(const at::Tensor& tensor, at::Stream stream) { + tensor.record_stream(stream); +} + +void record_tensor(const std::vector& tensors, at::Stream stream) { + for (auto& tensor : tensors) { + tensor.record_stream(stream); + } +} // Check that all `tensors' have the same device and type and shape and // are distributed across distinct GPUs if these are GPU tensors. @@ -45,7 +98,9 @@ c10::DeviceType check_tensors_properties(const std::vector& tensors) if (tensors.size() == 0) { throw std::runtime_error("Tensor list must be nonempty"); } - auto device_count = xpu::dpcpp::device_count(); + c10::Device device = tensors.front().device(); + c10::impl::VirtualGuardImpl impl(device.type()); + auto device_count = impl.deviceCount(); if (tensors.size() > static_cast(device_count)) { throw std::runtime_error( "Tensor list mustn't be larger than the number of available GPUs"); @@ -107,23 +162,37 @@ Comms& get_ccl_comms(c10d::ProcessGroupCCL& pg_ccl, const std::string& devices_k int local_base_rank = pg_ccl.getRank() * devices.size(); ccl::vector_class> devs_rank; - std::vector ccl_streams; + ccl::vector_class ccl_streams; ccl_streams.reserve(devices.size()); + std::vector torch_streams; + torch_streams.reserve(devices.size()); + + // Create the stream and rank dev mapping + for (size_t i = 0; i < devices.size(); i++) { + c10::impl::VirtualGuardImpl impl(devices[i].type()); + // XPU doesn't support prioritized stream. + c10::Stream stream = impl.getStreamFromGlobalPool(devices[i], /*isHighPriority=*/false); + torch_streams.push_back(stream); - // Use the same queue for computation and communication. - // TODO: IPEX doesn't support multiple queue for now. Copy engine requires a dedicate queue - auto q = xpu::dpcpp::getCurrentDPCPPStream(devices[0].index()).dpcpp_queue(); - ccl_streams.push_back(ccl::create_stream(q)); + auto q = xpu::get_queue_from_stream(stream); + ccl_streams.push_back(ccl::create_stream(q)); - int rank = local_base_rank; - devs_rank.emplace_back(rank, ccl::create_device(q.get_device())); + int rank = local_base_rank + i; + devs_rank.emplace_back(rank, ccl::create_device(q.get_device())); + } + // The IPEX use default global context. + // TODO: add get default global context API in IPEX. 
+ c10::impl::VirtualGuardImpl impl(devices[0].type()); + c10::Stream dpcpp_stream = impl.getStream(devices[0]); + auto q = xpu::get_queue_from_stream(dpcpp_stream); auto ctx = ccl::create_context(q.get_context()); - auto dpcpp_comms = ccl::create_communicators(total_rank_size, devs_rank, ctx, pg_ccl.ccl_member_->get_kvs(pg_ccl.getRank(), *pg_ccl.store_)); + // Create ccl::communicators + auto dpcpp_comms = ccl::create_communicators(total_rank_size, devs_rank, ctx, pg_ccl.ccl_member_->get_kvs(pg_ccl.getRank(), *pg_ccl.store_)); - std::shared_ptr dpcpp_comms_ptr = std::make_shared(dpcpp_comms, ccl_streams); // Store the comms to cache + std::shared_ptr dpcpp_comms_ptr = std::make_shared(dpcpp_comms, ccl_streams, torch_streams); pg_ccl.ccl_member_->add_comms(devices_key, dpcpp_comms_ptr); return *dpcpp_comms_ptr.get(); @@ -146,14 +215,27 @@ class XPUWorkCCL : public CollectiveAsyncWorkCCLinputs); + // add SYCL running dependency computation -> communication. + sync_streams(devices, this->comms.torch_streams); + + for (const auto i : c10::irange(this->inputs.size())) { + // Both `inputs' and `outputs' are created on a worker stream and used in + // different ncclStreams. Hence, both must record the ncclStream to + // prevent being freed before the collective finishes. + // + // We only record `inputs' here, and leave recording `outputs' to `fn' for + // operations where `inputs' and `outputs' are not the same. + // + // See [Sync Streams]. + record_tensor(this->inputs[i], this->comms.torch_streams[i]); + } + CollectiveAsyncWorkCCL::run(); - // add SYCL running dependency communication -> computation. }; // No explicitly synchronization. - virtual ~XPUWorkCCL() { - this->rets.clear(); - } + virtual ~XPUWorkCCL() {} // Waiting on the work's on XPU backend bool wait(std::chrono::milliseconds timeout) override { @@ -163,33 +245,27 @@ class XPUWorkCCL : public CollectiveAsyncWorkCCLcomms.torch_streams); + // under the stream guard. Mark the Future completing. + this->AsyncWorkCCL::finishAsyncWorkCCL(); + } private: }; -void execute(c10::intrusive_ptr work) { -// if(work->recordFunctionBeforeCallback_){ -// work->recordFunctionBeforeCallback_(); -// } - try { - work->run(); - } catch (...) 
{ - work->finishAsyncWorkCCLError(std::current_exception()); - return; - } - - work->finishAsyncWorkCCL(); -} - } //namespace anonymous class XPUCCLStubs final: public DispatchStub { public: - XPUCCLStubs() {} + XPUCCLStubs() { + stop_=false; + workerThread_ = std::thread(&XPUCCLStubs::runLoop, this); + } - ~XPUCCLStubs() {} + ~XPUCCLStubs() {destroy();} protected: @@ -202,6 +278,11 @@ class XPUCCLStubs final: public DispatchStub { const ReduceOptions& opts, ProcessGroupCCL& pg_ccl) override; + c10::intrusive_ptr _reduce_scatter_base_(at::Tensor& outputTensor, + at::Tensor& inputTensor, + const ReduceScatterOptions& opts, + ProcessGroupCCL& pg_ccl) override; + c10::intrusive_ptr broadcast_(std::vector& tensors, const BroadcastOptions& opts, ProcessGroupCCL& pg_ccl) override; @@ -211,6 +292,11 @@ class XPUCCLStubs final: public DispatchStub { const AllgatherOptions& opts, ProcessGroupCCL& pg_ccl) override; + c10::intrusive_ptr _allgather_base_(at::Tensor& outputTensor, + at::Tensor& inputTensor, + const AllgatherOptions& opts, + ProcessGroupCCL& pg_ccl) override; + c10::intrusive_ptr gather_(std::vector>& outputTensors, std::vector& inputTensors, const GatherOptions& opts, @@ -228,10 +314,17 @@ class XPUCCLStubs final: public DispatchStub { const AllToAllOptions& opts, ProcessGroupCCL& pg) override; - void reset() override { - } - + void destroy(); + void reset() override {} + void runLoop(); + c10::intrusive_ptr execute(c10::intrusive_ptr & work); private: + bool stop_; + std::mutex pgMutex_; + std::thread workerThread_; + std::deque> queue_; + std::condition_variable queueProduceCV_; + std::condition_variable queueConsumeCV_; }; struct RegisterXPUMethods { @@ -251,6 +344,67 @@ void checkGPUTensor(const std::vector& tensors) checkGPUTensor(tensors[0]); } +c10::intrusive_ptr XPUCCLStubs::execute(c10::intrusive_ptr & work){ + try { + work->run(); + } catch (...) { + work->finishAsyncWorkCCLError(std::current_exception()); + return work; + } + // mark the work finished asynchronizely. + work->finishAsyncWorkCCL(); + + // Track the work internal + std::unique_lock lock(pgMutex_); + queue_.push_back(work); + lock.unlock(); + queueProduceCV_.notify_one(); + + return work; +} + +void XPUCCLStubs::destroy() { + std::unique_lock lock(pgMutex_); + queueConsumeCV_.wait(lock, [&] { return queue_.empty(); }); + + // Queue is empty, signal stop + stop_ = true; + + // Release lock to allow threads to terminate + lock.unlock(); + queueProduceCV_.notify_all(); + + // Join the single worker thread + workerThread_.join(); +} + +void XPUCCLStubs::runLoop() { + std::unique_lock lock(pgMutex_); + while (!stop_) { + if (queue_.empty()) { + queueProduceCV_.wait(lock); + continue; + } + + auto work = std::move(queue_.front()); + + queue_.pop_front(); + + lock.unlock(); + queueConsumeCV_.notify_one(); + + try { + work->synchronize(); +// work->finishAsyncWorkCCL(); + + } catch (...) 
{ +// work->finishAsyncWorkCCLError(std::current_exception()); + } + + lock.lock(); + } +} + c10::intrusive_ptr XPUCCLStubs::allreduce_(std::vector& tensors, const AllreduceOptions& opts, ProcessGroupCCL& pg_ccl) { @@ -269,19 +423,18 @@ c10::intrusive_ptr XPUCCLStubs::allreduce_(std::v ccl::event ret_evt; call_with_lock(c10d::ProcessGroupCCL::globalMutex, [&](){ - CCL_CHECK(ret_evt = ccl::allreduce(input.data_ptr(), + CCL_KERNEL_SUBMIT(ret_evt = ccl::allreduce(input.data_ptr(), output.data_ptr(), (size_t) input.numel(), cclDatatypes.at(input.scalar_type()), cclOps.at(opts.reduceOp), comm, stream, - attr);); + attr), stream.get_native()); }); return ret_evt; }, - c10d::OpType::ALLREDUCE, - "oneccl_bindings_for_pytorch::xpu_work::allreduce"); + c10d::OpType::ALLREDUCE); work->debugName = std::string("xpu::allreduce"); execute(work); @@ -308,20 +461,19 @@ c10::intrusive_ptr XPUCCLStubs::reduce_(std::vect ccl::event ret_evt; call_with_lock(c10d::ProcessGroupCCL::globalMutex, [&]() { - CCL_CHECK(ret_evt = ccl::reduce(input.data_ptr(), + CCL_KERNEL_SUBMIT(ret_evt = ccl::reduce(input.data_ptr(), output.data_ptr(), (size_t) input.numel(), cclDatatypes.at(input.scalar_type()), cclOps.at(opts.reduceOp), root, comm, - stream);); + stream), stream.get_native()); }); return ret_evt; }, - c10d::OpType::REDUCE, - "oneccl_bindings_for_pytorch::xpu_work::reduce"); + c10d::OpType::REDUCE); work->debugName = std::string("xpu::reduce"); execute(work); @@ -329,6 +481,56 @@ c10::intrusive_ptr XPUCCLStubs::reduce_(std::vect return work; } +c10::intrusive_ptr XPUCCLStubs::_reduce_scatter_base_(at::Tensor& outputTensor, + at::Tensor& inputTensor, + const ReduceScatterOptions& opts, + ProcessGroupCCL& pg_ccl) { + + checkGPUTensor({outputTensor, inputTensor}); + const int world_size = pg_ccl.getSize(); + if (inputTensor.numel() != outputTensor.numel() * world_size) { + TORCH_CHECK( + false, + "input tensor must be the same size as output size times world size"); + } + + // just a wrapper to fit the collective interface + auto inputs = std::vector {inputTensor}; + auto outputs = std::vector {outputTensor}; + + c10::intrusive_ptr work; + work = collective( + pg_ccl, + inputs, + outputs, + [=](at::Tensor input, + at::Tensor output, + ccl::reduce_attr attr, + ccl::communicator& comm, + ccl::stream& stream) { + RECORD_FUNCTION("oneccl_bindings_for_pytorch::xpu::_reduce_scatter_base", std::vector{input}); + + ccl::event ret_evt; + call_with_lock(c10d::ProcessGroupCCL::globalMutex, [&]() { + CCL_KERNEL_SUBMIT(ret_evt = ccl::reduce_scatter(input.data_ptr(), + output.data_ptr(), + (size_t) output.numel(), + cclDatatypes.at(input.scalar_type()), + cclOps.at(opts.reduceOp), + comm, + stream), stream.get_native()); + }); + return ret_evt; + + }, + c10d::OpType::_REDUCE_SCATTER_BASE); + + work->debugName = std::string("xpu::_reduce_scatter_base"); + execute(work); + + return work; +} + c10::intrusive_ptr XPUCCLStubs::broadcast_(std::vector& tensors, const BroadcastOptions &opts, ProcessGroupCCL& pg_ccl) { @@ -348,18 +550,17 @@ c10::intrusive_ptr XPUCCLStubs::broadcast_(std::v ccl::event ret_evt; call_with_lock(c10d::ProcessGroupCCL::globalMutex, [&](){ - CCL_CHECK(ret_evt = ccl::broadcast(input.data_ptr(), + CCL_KERNEL_SUBMIT(ret_evt = ccl::broadcast(input.data_ptr(), (size_t) input.numel(), cclDatatypes.at(input.scalar_type()), root, comm, stream, - attr)); + attr), stream.get_native()); }); return ret_evt; }, - c10d::OpType::BROADCAST, - "oneccl_bindings_for_pytorch::xpu_work::broadcast"); + c10d::OpType::BROADCAST); 
work->debugName = std::string("xpu::broadcast"); @@ -401,19 +602,18 @@ c10::intrusive_ptr XPUCCLStubs::allgather_(std::v }); call_with_lock(c10d::ProcessGroupCCL::globalMutex, [&]() { - CCL_CHECK(ret_evt = ccl::allgatherv(input.data_ptr(), + CCL_KERNEL_SUBMIT(ret_evt = ccl::allgatherv(input.data_ptr(), (size_t) input.numel(), recvBufs, recvCounts, cclDatatypes.at(input.scalar_type()), comm, - stream);); + stream), stream.get_native()); }); return ret_evt; }, - c10d::OpType::ALLGATHER, - "oneccl_bindings_for_pytorch::xpu_work::allgather"); + c10d::OpType::ALLGATHER); work->debugName = std::string("xpu::allgather"); execute(work); @@ -421,6 +621,52 @@ c10::intrusive_ptr XPUCCLStubs::allgather_(std::v return work; } +c10::intrusive_ptr XPUCCLStubs::_allgather_base_(at::Tensor& outputTensor, + at::Tensor& inputTensor, + const AllgatherOptions& opts, + ProcessGroupCCL& pg_ccl) { + const int world_size = pg_ccl.getSize(); + if (inputTensor.numel() * world_size != outputTensor.numel()) { + TORCH_CHECK(false, "output tensor size must be equal to world_size times input tensor size"); + } + + // just a wrapper to fit the collective interface + auto inputs = std::vector {inputTensor}; + auto outputs = std::vector {outputTensor}; + + c10::intrusive_ptr work; + work = collective( + pg_ccl, + inputs, + outputs, + [=](at::Tensor input, + at::Tensor output, + ccl::allgatherv_attr attr, + ccl::communicator& comm, + ccl::stream& stream) { + RECORD_FUNCTION("oneccl_bindings_for_pytorch::xpu::_allgather_base_", std::vector({input})); + + std::vector recvCounts(world_size, input.numel()); + + ccl::event ret_evt; + call_with_lock(c10d::ProcessGroupCCL::globalMutex, [&]() { + CCL_KERNEL_SUBMIT(ret_evt = ccl::allgatherv(input.data_ptr(), + (size_t) input.numel(), + output.data_ptr(), + recvCounts, + cclDatatypes.at(input.scalar_type()), + comm, + stream), stream.get_native()); + }); + return ret_evt; + }, + c10d::OpType::_ALLGATHER_BASE); + + work->debugName = std::string("xpu::_allgather_base_"); + execute(work); + + return work; +} c10::intrusive_ptr XPUCCLStubs::gather_(std::vector>& outputTensors, std::vector& inputTensors, @@ -479,16 +725,14 @@ c10::intrusive_ptr XPUCCLStubs::gather_(std::vect } ccl::event ret_evt; - CCL_DISPATCH_INTEGRAL_FLOATS_TYPES(input.scalar_type(), "gather", [&] { - call_with_lock(c10d::ProcessGroupCCL::globalMutex, [&](){ - CCL_CHECK(ret_evt = ccl::alltoallv(input.data_ptr(), - sendCounts, - flatOutput.data_ptr(), - recvCounts, - cclDatatypes.at(flatOutput.scalar_type()), - comm, - stream);); - }); + call_with_lock(c10d::ProcessGroupCCL::globalMutex, [&](){ + CCL_KERNEL_SUBMIT(ret_evt = ccl::alltoallv(input.data_ptr(), + sendCounts, + flatOutput.data_ptr(), + recvCounts, + cclDatatypes.at(flatOutput.scalar_type()), + comm, + stream), stream.get_native()); }); // TODO : add to post and pre hooks @@ -511,8 +755,7 @@ c10::intrusive_ptr XPUCCLStubs::gather_(std::vect return ret_evt; }, - c10d::OpType::GATHER, - "oneccl_bindings_for_pytorch::xpu_work::gather"); + c10d::OpType::GATHER); work->debugName = std::string("xpu::gather"); execute(work); @@ -550,22 +793,20 @@ c10::intrusive_ptr XPUCCLStubs::alltoall_base_(at ccl::communicator& comm, ccl::stream& stream) { ccl::event ret_evt; - CCL_DISPATCH_INTEGRAL_FLOATS_TYPES(input.scalar_type(), "alltoall_base", [&] { - call_with_lock(c10d::ProcessGroupCCL::globalMutex, [&](){ - CCL_CHECK(ret_evt = ccl::alltoall(input.data_ptr(), - output.data_ptr(), - (size_t)output.numel() / comm.size(), - cclDatatypes.at(output.scalar_type()), - comm, - 
stream, - attr);); - }); + + call_with_lock(c10d::ProcessGroupCCL::globalMutex, [&](){ + CCL_KERNEL_SUBMIT(ret_evt = ccl::alltoall(input.data_ptr(), + output.data_ptr(), + (size_t)output.numel() / comm.size(), + cclDatatypes.at(output.scalar_type()), + comm, + stream, + attr), stream.get_native()); }); return ret_evt; }, - c10d::OpType::ALLTOALL_BASE, - "oneccl_bindings_for_pytorch::xpu_work::alltoall_base"); + c10d::OpType::ALLTOALL_BASE); } else{ // Need alltoallv @@ -579,41 +820,38 @@ c10::intrusive_ptr XPUCCLStubs::alltoall_base_(at ccl::communicator& comm, ccl::stream& stream) { ccl::event ret_evt; - CCL_DISPATCH_INTEGRAL_FLOATS_TYPES(input.scalar_type(), "alltoall_base", [&] { - c10d::checkSplitSizes(inputSplitSizes, input, grp_size); - c10d::checkSplitSizes(outputSplitSizes, output, grp_size); - - std::vector sendCounts(grp_size); - std::vector recvCounts(grp_size); - bool inputSplitsEqual = inputSplitSizes.size() == 0; - bool outputSplitsEqual = outputSplitSizes.size() == 0; - - size_t inLen = input.numel(); - size_t outLen = output.numel(); - if (inLen) inLen /= (inputSplitsEqual ? grp_size : input.size(0)); - if (outLen) outLen /= (outputSplitsEqual ? grp_size : output.size(0)); - - for (int i = 0; i < grp_size; i++) - { - sendCounts[i] = (inputSplitsEqual ? inLen : inputSplitSizes[i] * inLen); - recvCounts[i] = (outputSplitsEqual ? outLen : outputSplitSizes[i] * outLen); - } - - call_with_lock(c10d::ProcessGroupCCL::globalMutex, [&](){ - CCL_CHECK(ret_evt = ccl::alltoallv(input.data_ptr(), - sendCounts, - output.data_ptr(), - recvCounts, - cclDatatypes.at(output.scalar_type()), - comm, - stream, - attr);); - }); + c10d::checkSplitSizes(inputSplitSizes, input, grp_size); + c10d::checkSplitSizes(outputSplitSizes, output, grp_size); + + std::vector sendCounts(grp_size); + std::vector recvCounts(grp_size); + bool inputSplitsEqual = inputSplitSizes.size() == 0; + bool outputSplitsEqual = outputSplitSizes.size() == 0; + + size_t inLen = input.numel(); + size_t outLen = output.numel(); + if (inLen) inLen /= (inputSplitsEqual ? grp_size : input.size(0)); + if (outLen) outLen /= (outputSplitsEqual ? grp_size : output.size(0)); + + for (int i = 0; i < grp_size; i++) + { + sendCounts[i] = (inputSplitsEqual ? inLen : inputSplitSizes[i] * inLen); + recvCounts[i] = (outputSplitsEqual ? 
outLen : outputSplitSizes[i] * outLen); + } + + call_with_lock(c10d::ProcessGroupCCL::globalMutex, [&](){ + CCL_KERNEL_SUBMIT(ret_evt = ccl::alltoallv(input.data_ptr(), + sendCounts, + output.data_ptr(), + recvCounts, + cclDatatypes.at(output.scalar_type()), + comm, + stream, + attr), stream.get_native()); }); return ret_evt; }, - c10d::OpType::ALLTOALL_BASE, - "oneccl_bindings_for_pytorch::xpu_work::alltoall_base"); + c10d::OpType::ALLTOALL_BASE); } work->debugName = std::string("xpu::alltoall_base"); @@ -639,7 +877,10 @@ c10::intrusive_ptr XPUCCLStubs::alltoall_(std::ve std::vector outputs, ccl::alltoallv_attr attr, ccl::communicator& comm, - ccl::stream& stream) { + ccl::stream& stream, + c10::Stream& torch_stream) { + + c10::OptionalStreamGuard stream_guard(torch_stream); at::Tensor flatInput; at::Tensor flatOutput; @@ -669,17 +910,14 @@ c10::intrusive_ptr XPUCCLStubs::alltoall_(std::ve } ccl::event ret_evt; - CCL_DISPATCH_INTEGRAL_FLOATS_TYPES(flatInput.scalar_type(), "xpu::alltoall", [&] { - call_with_lock(c10d::ProcessGroupCCL::globalMutex, [&](){ - CCL_CHECK(ret_evt = ccl::alltoallv(flatInput.data_ptr(), - sendCounts, - flatOutput.data_ptr(), - recvCounts, - cclDatatypes.at(flatOutput.scalar_type()), - comm, - stream);); - }); - + call_with_lock(c10d::ProcessGroupCCL::globalMutex, [&](){ + CCL_KERNEL_SUBMIT(ret_evt = ccl::alltoallv(flatInput.data_ptr(), + sendCounts, + flatOutput.data_ptr(), + recvCounts, + cclDatatypes.at(flatOutput.scalar_type()), + comm, + stream), stream.get_native()); }); if (!isOutputFlat) { @@ -695,8 +933,7 @@ c10::intrusive_ptr XPUCCLStubs::alltoall_(std::ve } return ret_evt; }, - c10d::OpType::ALLTOALL, - "oneccl_bindings_for_pytorch::xpu_work::alltoall"); + c10d::OpType::ALLTOALL); work->debugName = std::string("xpu::alltoall"); execute(work); @@ -706,4 +943,4 @@ c10::intrusive_ptr XPUCCLStubs::alltoall_(std::ve RegisterXPUMethods xpu_register; -} \ No newline at end of file +} diff --git a/src/utils.cpp b/src/utils.cpp index 19c5a52..68895fb 100644 --- a/src/utils.cpp +++ b/src/utils.cpp @@ -43,16 +43,17 @@ std::map cclOps = {ReduceOp::PRODUCT, ccl::reduction::prod}, }; - std::map cclDatatypes = { {at::kByte, ccl::datatype::uint8}, - {at::kChar, ccl::datatype::uint8}, + {at::kChar, ccl::datatype::int8}, + {at::kShort, ccl::datatype::int16}, + {at::kInt, ccl::datatype::int32}, + {at::kLong, ccl::datatype::int64}, + {at::kHalf, ccl::datatype::float16}, + {at::kFloat, ccl::datatype::float32}, {at::kDouble, ccl::datatype::float64}, {at::kBFloat16, ccl::datatype::bfloat16}, - {at::kFloat, ccl::datatype::float32}, - {at::kInt, ccl::datatype::int32}, - {at::kLong, ccl::datatype::int64} }; // Get the key from the list of devices diff --git a/src/utils.h b/src/utils.h index 2696e0e..3318d94 100644 --- a/src/utils.h +++ b/src/utils.h @@ -35,7 +35,14 @@ #include #include #include + +#include +#if TORCH_VERSION_MINOR >= 13 #include +#else +#include +#endif + #include #include "ProcessGroupCCL.hpp" @@ -53,6 +60,7 @@ constexpr uint64_t kSynchronizeBusyWaitMicro = 10; // 50us } \ }while(0) + #define CCL_DISPATCH_INTEGRAL_FLOATS_TYPES(TYPE, NAME, ...) 
\ [&] { \ const auto& the_type = TYPE; \ @@ -210,19 +218,27 @@ class CollectiveAsyncWorkCCL : public ProcessGroupCCL::AsyncWorkCCL { f(f), comms(comms), attr(attr), inputs(inputs), opTimeout_(timeout) {} void run() override { - using Indices = std::make_index_sequence; - workStartTime_ = std::chrono::steady_clock::now(); - run_wrap_(Indices{}); + if constexpr (num_params == 6) { + workStartTime_ = std::chrono::steady_clock::now(); + run_wrap_(); + } + else{ + using Indices = std::make_index_sequence; + workStartTime_ = std::chrono::steady_clock::now(); + run_wrap_(Indices{}); + } }; virtual ~CollectiveAsyncWorkCCL() { +#if 0 if (!rets.empty()) { std::cerr << "attempted destruction of WorkCCL before work has completed, " << "waiting the request." << std::endl; synchronize(); } +#endif } bool isCompleted() override { @@ -284,8 +300,6 @@ class CollectiveAsyncWorkCCL : public ProcessGroupCCL::AsyncWorkCCL { std::this_thread::sleep_for( std::chrono::microseconds (kSynchronizeBusyWaitMicro)); } - - this->rets.clear(); } void synchronize() override { @@ -324,6 +338,33 @@ class CollectiveAsyncWorkCCL : public ProcessGroupCCL::AsyncWorkCCL { } } + template + typename std::enable_if::value, void>::type run_wrap_() { + if (rets.empty()) { + auto& outputs = outputTensors_; + for (size_t i = 0; i < inputs.size(); i++) { + CCL_CHECK(rets.push_back(f(inputs[i], outputs[i], attr, comms.comms[i], comms.streams[i], comms.torch_streams[i]))); + } + } + else { + // add warning for re run the ccl work + } + } + + template + typename std::enable_if::value, void>::type run_wrap_() { + if (rets.empty()) { + auto& outputs = outputTensors_[0]; + for (size_t i = 0; i < inputs.size(); i++) { + CCL_CHECK(rets.push_back(f(inputs[i], outputs[i], attr, comms.comms[i], comms.streams[i], comms.torch_streams[i]))); + } + } + else { + // add warning for re run the ccl work + } + } + + template ::value, bool> = true> ccl::event& get_event_from_ret_(R& ret) { @@ -376,7 +417,7 @@ c10::intrusive_ptr collective( pre_process pre, post_process post, c10d::OpType op_type, - const char* prof_title) { + const char* prof_title = nullptr) { using traits = function_traits; using attr_t = typename traits::template arg<2>::type; attr_t attr = ccl::create_operation_attr(); @@ -399,7 +440,7 @@ c10::intrusive_ptr collective( std::vector& outputs, fn fun, c10d::OpType op_type, - const char* prof_title) { + const char* prof_title = nullptr) { return collective( pg_ccl, inputs, diff --git a/tests/test_c10d_ccl.py b/tests/test_c10d_ccl.py index 0cc5333..8dba783 100644 --- a/tests/test_c10d_ccl.py +++ b/tests/test_c10d_ccl.py @@ -297,6 +297,70 @@ def test_allgather_basics_xpu(self): def test_allgather_basics_multi_xpu(self): self._test_allgather_basics(lambda t: t.clone().xpu("xpu:{}".format(self.rank))) + def _test_allgather_base_ops(self, fn): + store = c10d.FileStore(self.file_name, self.world_size) + pg = c10d.ProcessGroupCCL(store, self.rank, self.world_size) + + def allgather_base(output_t, input_t): + work = pg._allgather_base(output_t, input_t) + work.wait() + + tensor = fn(torch.tensor([self.rank])) + output_t = fn(torch.empty((self.world_size), dtype=tensor.dtype)) + + allgather_base(output_t, tensor) + + # Verification + self.assertEqual(torch.arange(self.world_size), output_t) + + def test_allgather_base_ops(self): + self._test_allgather_base_ops(lambda t: t.clone()) + + @skip_if_no_xpu + def test_allgather_base_ops_xpu(self): + self._test_allgather_base_ops(lambda t: t.clone().xpu()) + + @skip_if_not_multixpu + def 
test_allgather_basics_multi_xpu(self): + self._test_allgather_basics(lambda t: t.clone().xpu("xpu:{}".format(self.rank))) + + def _test_allgather_base_basics(self, fn): + store = c10d.FileStore(self.file_name, self.world_size) + pg = c10d.ProcessGroupCCL(store, self.rank, self.world_size) + + def allgather_base(output_t, input_t): + work = pg._allgather_base(output_t, input_t) + work.wait() + + # anticpate an error + with self.assertRaisesRegex( + RuntimeError, + "output tensor size must be equal to world_size times input tensor size", + ): + tensor = fn(torch.tensor([self.rank])) + output_t = fn(torch.empty((self.world_size + 1), dtype=tensor.dtype)) + # fails the check because output_t is not correctly sized + allgather_base(output_t, tensor) + + # anticpate an error + with self.assertRaisesRegex( + RuntimeError, "Tensors are not equal in data type" + ): + tensor = fn(torch.tensor([self.rank], dtype=torch.float)) + output_t = fn(torch.empty((self.world_size + 1), dtype=torch.long)) + # fails the check because the dtype is different + allgather_base(output_t, tensor) + + def test_allgather_base_basics(self): + self._test_allgather_base_basics(lambda t: t.clone()) + + @skip_if_no_xpu + def test_allgather_base_basics_xpu(self): + self._test_allgather_base_basics(lambda t: t.clone().xpu()) + + @skip_if_not_multixpu + def test_allgather_base_basics_multi_xpu(self): + self._test_allgather_base_basics(lambda t: t.clone().xpu("xpu:{}".format(self.rank))) # alltoall_base def _test_alltoall_base_equal_split_helper(self, fn): @@ -387,6 +451,72 @@ def test_alltoall_basics_xpu(self): def test_alltoall_basics_multi_xpu(self): self._test_all_to_all_helper(lambda t: t.clone().xpu("xpu:{}".format(self.rank))) + def _test_reduce_scatter_base_basics(self, fn): + store = c10d.FileStore(self.file_name, self.world_size) + pg = c10d.ProcessGroupCCL(store, self.rank, self.world_size) + + def reduce_scatter_base(output_t, input_t): + work = pg._reduce_scatter_base(output_t, input_t) + work.wait() + + # anticpate an error + with self.assertRaisesRegex( + RuntimeError, + "input tensor must be the same size as output size times world size", + ): + input_t = fn(torch.tensor([self.rank])) + output_t = fn(torch.empty((self.world_size + 1), dtype=input_t.dtype)) + # fails the check because output_t is not correctly sized + reduce_scatter_base(output_t, input_t) + + # anticpate an error + with self.assertRaisesRegex( + RuntimeError, "Tensors are not equal in data type" + ): + tensor = fn(torch.tensor([self.rank], dtype=torch.float)) + output_t = fn(torch.empty((self.world_size + 1), dtype=torch.long)) + # fails the check because the dtype is different + reduce_scatter_base(output_t, tensor) + + def test_reduce_scatter_base_basics(self): + self._test_reduce_scatter_base_basics(lambda t: t.clone()) + + @skip_if_no_xpu + def test_reduce_scatter_base_basics_xpu(self): + self._test_reduce_scatter_base_basics(lambda t: t.clone().xpu()) + + @skip_if_not_multixpu + def test_reduce_scatter_base_basics_multi_xpu(self): + self._test_reduce_scatter_base_basics(lambda t: t.clone().xpu("xpu:{}".format(self.rank))) + + def _test_reduce_scatter_base_ops(self, fn): + store = c10d.FileStore(self.file_name, self.world_size) + pg = c10d.ProcessGroupCCL(store, self.rank, self.world_size) + + def reduce_scatter_base(output_t, input_t): + work = pg._reduce_scatter_base(output_t, input_t) + work.wait() + + # reduce_scatter_base is GPU number agnostic. 
+ # Each rank contribute one tensor regardless of GPU counts + output_t = fn(torch.empty([1])) + tensor = fn(torch.arange(self.world_size, dtype=output_t.dtype)) + + reduce_scatter_base(output_t, tensor) + + # Verification + self.assertEqual(output_t[0], self.rank * self.world_size) + + def test_reduce_scatter_base(self): + self._test_reduce_scatter_base_ops(lambda t: t.clone()) + + @skip_if_no_xpu + def test_reduce_scatter_base_xpu(self): + self._test_reduce_scatter_base_ops(lambda t: t.clone().xpu()) + + @skip_if_not_multixpu + def test_reduce_scatter_base_multi_xpu(self): + self._test_reduce_scatter_base_ops(lambda t: t.clone().xpu("xpu:{}".format(self.rank))) if __name__ == '__main__': run_tests() diff --git a/third_party/oneCCL b/third_party/oneCCL index debdc21..bfa1e99 160000 --- a/third_party/oneCCL +++ b/third_party/oneCCL @@ -1 +1 @@ -Subproject commit debdc21aba5230ed0ea945b14d8f61dfeedfb535 +Subproject commit bfa1e9944422111453299e7177dbb103f1f6bf2f diff --git a/tools/setup/env.py b/tools/setup/env.py index 9dfd675..a4a8868 100644 --- a/tools/setup/env.py +++ b/tools/setup/env.py @@ -10,11 +10,16 @@ def get_compiler(runtime): if runtime == 'dpcpp': - cc = shutil.which('icx') - cpp = shutil.which('dpcpp') + c_compiler = 'icx' + cpp_compiler = 'icpx' else: - cc = shutil.which('cc') - cpp = shutil.which('c++') + c_compiler = 'cc' + cpp_compiler = 'c++' + + cc = shutil.which(c_compiler) + cpp = shutil.which(cpp_compiler) + if cpp is None or cc is None: + raise RuntimeError("couldn't find the compiler '{}' or '{}'".format(c_compiler, cpp_compiler)) return cc, cpp diff --git a/version.txt b/version.txt index feaae22..61b11cb 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -1.13.0 +1.13.100+gpu From 8ba9fe398a700829a6cdfff012c8c8be6724597d Mon Sep 17 00:00:00 2001 From: chengjunlu Date: Fri, 6 Jan 2023 21:13:49 +0800 Subject: [PATCH 2/9] Update README.md Update the readme about the torch-ccl XPU usage and the branch information --- README.md | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index f1a7abb..ad31947 100644 --- a/README.md +++ b/README.md @@ -36,6 +36,7 @@ We recommend Anaconda as Python package management system. 
The following is the | `torch` | `oneccl_bindings_for_pytorch` | | :-------------------------------------------------------------: | :-----------------------------------------------------------------------: | | `master` | `master` | + | [v1.13](https://github.com/pytorch/pytorch/tree/v1.13) | [ccl_torch1.13.100](https://github.com/intel/torch-ccl/tree/ccl_torch1.13.100) | | [v1.13](https://github.com/pytorch/pytorch/tree/v1.13) | [ccl_torch1.13](https://github.com/intel/torch-ccl/tree/ccl_torch1.13) | | [v1.12.1](https://github.com/pytorch/pytorch/tree/v1.12.1) | [ccl_torch1.12.100](https://github.com/intel/torch-ccl/tree/ccl_torch1.12.100) | | [v1.12.0](https://github.com/pytorch/pytorch/tree/v1.12.0) | [ccl_torch1.12](https://github.com/intel/torch-ccl/tree/ccl_torch1.12) | @@ -64,7 +65,8 @@ The following build options are supported in Intel® oneCCL Bindings for PyTorch | COMPUTE_BACKEND | | Set oneCCL `COMPUTE_BACKEDN`,set to `dpcpp` and use DPC++ Compiler to enable support for Intel XPU | | CCL_PACKAGE_NAME | oneccl-bind-pt | Set Wheel Name | | ONECCL_BINDINGS_FOR_PYTORCH_BACKEND | cpu | Set BACKEND | -| CCL_SHA_VERSION | False |add git head sha version to Wheel name | +| CCL_SHA_VERSION | False | Add git head sha version to Wheel name | +| BUILD_NO_ONECCL_PACKAGE | False | Package the Wheel without oneCCL library | ## Lunch Option List @@ -93,8 +95,14 @@ The following lunch options are supported in Intel® oneCCL Bindings for PyTorch # for CPU Backend Only python setup.py install # use DPC++ Compiler to enable support for Intel XPU - COMPUTE_BACKEND=dpcpp python setup.py install + BUILD_NO_ONECCL_PACKAGE=ON COMPUTE_BACKEND=dpcpp python setup.py install ``` + +**Note:** To run the torch-ccl without oneCCL library installed, Please make sure you have installed oneCCL in the oneAPI basekit from https://www.intel.com/content/www/us/en/developer/tools/oneapi/toolkits.html#base-kit + +```bash +source $basekit_root/ccl/latest/env/vars.sh +``` ### Install PreBuilt Wheel @@ -102,6 +110,7 @@ Wheel files are avaiable for the following Python versions. | Extension Version | Python 3.6 | Python 3.7 | Python 3.8 | Python 3.9 | Python 3.10 | | :---------------: | :--------: | :--------: | :--------: | :--------: | :---------: | +| 1.13.100 | | √ | √ | √ | √ | | 1.13 | | √ | √ | √ | √ | | 1.12.100 | | √ | √ | √ | √ | | 1.12.0 | | √ | √ | √ | √ | From 818caf38e5a1803ee24e7ace74142e39bb079958 Mon Sep 17 00:00:00 2001 From: chengjunlu Date: Tue, 10 Jan 2023 10:25:59 +0800 Subject: [PATCH 3/9] Update README.md Update the GPU wheel install instruction. --- README.md | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index ad31947..fa5893f 100644 --- a/README.md +++ b/README.md @@ -117,10 +117,14 @@ Wheel files are avaiable for the following Python versions. | 1.11.0 | | √ | √ | √ | √ | | 1.10.0 | √ | √ | √ | √ | | +Installation for CPU: ```bash python -m pip install oneccl_bind_pt==1.13 -f https://developer.intel.com/ipex-whl-stable-cpu ``` - +Installation for GPU: +```bash +python -m pip install oneccl_bind_pt -f https://developer.intel.com/ipex-whl-stable-xpu +``` ## Usage example.py From ab75c5b92fca030aca75dc7bd30be5022ce6a3d9 Mon Sep 17 00:00:00 2001 From: chengjunlu Date: Tue, 10 Jan 2023 10:30:15 +0800 Subject: [PATCH 4/9] Update README.md Correct the torch release 1.13.0 source address. 
--- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index fa5893f..6344420 100644 --- a/README.md +++ b/README.md @@ -36,8 +36,8 @@ We recommend Anaconda as Python package management system. The following is the | `torch` | `oneccl_bindings_for_pytorch` | | :-------------------------------------------------------------: | :-----------------------------------------------------------------------: | | `master` | `master` | - | [v1.13](https://github.com/pytorch/pytorch/tree/v1.13) | [ccl_torch1.13.100](https://github.com/intel/torch-ccl/tree/ccl_torch1.13.100) | - | [v1.13](https://github.com/pytorch/pytorch/tree/v1.13) | [ccl_torch1.13](https://github.com/intel/torch-ccl/tree/ccl_torch1.13) | + | [v1.13.0](https://github.com/pytorch/pytorch/tree/v1.13.0) | [ccl_torch1.13.100](https://github.com/intel/torch-ccl/tree/ccl_torch1.13.100) | + | [v1.13.0](https://github.com/pytorch/pytorch/tree/v1.13.0) | [ccl_torch1.13](https://github.com/intel/torch-ccl/tree/ccl_torch1.13) | | [v1.12.1](https://github.com/pytorch/pytorch/tree/v1.12.1) | [ccl_torch1.12.100](https://github.com/intel/torch-ccl/tree/ccl_torch1.12.100) | | [v1.12.0](https://github.com/pytorch/pytorch/tree/v1.12.0) | [ccl_torch1.12](https://github.com/intel/torch-ccl/tree/ccl_torch1.12) | | [v1.11.0](https://github.com/pytorch/pytorch/tree/v1.11.0) | [ccl_torch1.11](https://github.com/intel/torch-ccl/tree/ccl_torch1.11) | From d616d9b71fdbef169f02133b5ec3f86437962c26 Mon Sep 17 00:00:00 2001 From: liangan1 Date: Tue, 17 Jan 2023 09:31:39 +0800 Subject: [PATCH 5/9] Enable reduce scatter (#58) * Enable _allgather_base (#53) * Enable _allgather_base (#53) * Update README.md * Enable _reduce_scatter_base --- src/ProcessGroupCCL.cpp | 22 ++++++++-------- src/ProcessGroupCCL.hpp | 5 ++++ src/cpu/cpu_ccl.cpp | 57 ++++++++++++++++++++++++++++++++++++++++- src/dispatch_stub.cpp | 16 ++++++++++++ src/dispatch_stub.h | 14 ++++++++++ src/gpu/dpcpp_ccl.cpp | 56 ++++++++++++++++++++++++++++++++++++++++ tests/test_c10d_ccl.py | 24 +++++++++++++++++ 7 files changed, 181 insertions(+), 13 deletions(-) diff --git a/src/ProcessGroupCCL.cpp b/src/ProcessGroupCCL.cpp index 991d0b2..03a217b 100644 --- a/src/ProcessGroupCCL.cpp +++ b/src/ProcessGroupCCL.cpp @@ -248,7 +248,6 @@ c10::intrusive_ptr ProcessGroupCCL::_allgather_base( format_tensors_param(tensor_param, inputTensor); format_tensors_param(tensor_param, outputTensor); RECORD_FUNCTION("oneccl_bindings_for_pytorch::_allgather_base", tensor_param); - auto work = DispatchStub::_allgather_base(outputTensor, inputTensor, opts, *this); return work; } @@ -297,18 +296,17 @@ c10::intrusive_ptr ProcessGroupCCL::reduce_scatter( TORCH_CHECK(false, "ProcessGroupCCL does not support reduce_scatter"); } - c10::intrusive_ptr ProcessGroupCCL::_reduce_scatter_base( - at::Tensor& outputTensor, - at::Tensor& inputTensor, - const ReduceScatterOptions& opts) { - std::vector tensor_param; - format_tensors_param(tensor_param, inputTensor); - format_tensors_param(tensor_param, outputTensor); - RECORD_FUNCTION("oneccl_bindings_for_pytorch::_reduce_scatter_base", tensor_param); - - auto work = DispatchStub::_reduce_scatter_base(outputTensor, inputTensor, opts, *this); - return work; + at::Tensor& outputTensor, + at::Tensor& inputTensor, + const ReduceScatterOptions& opts) +{ + std::vector tensor_param; + format_tensors_param(tensor_param, inputTensor); + format_tensors_param(tensor_param, outputTensor); + RECORD_FUNCTION("oneccl_bindings_for_pytorch::_reduce_scatter_base", 
tensor_param); + auto work = DispatchStub::_reduce_scatter_base(outputTensor, inputTensor, opts, *this); + return work; } c10::intrusive_ptr ProcessGroupCCL::alltoall_base( diff --git a/src/ProcessGroupCCL.hpp b/src/ProcessGroupCCL.hpp index ef534d0..2bbda33 100644 --- a/src/ProcessGroupCCL.hpp +++ b/src/ProcessGroupCCL.hpp @@ -174,6 +174,11 @@ class ProcessGroupCCL : public ProcessGroup std::vector& outputTensors, std::vector>& inputTensors, const ReduceScatterOptions& opts = ReduceScatterOptions()) override; + + c10::intrusive_ptr _reduce_scatter_base( + at::Tensor& outputTensor, + at::Tensor& inputTensor, + const ReduceScatterOptions& opts = ReduceScatterOptions()) override; c10::intrusive_ptr _reduce_scatter_base( at::Tensor& outputBuffer, diff --git a/src/cpu/cpu_ccl.cpp b/src/cpu/cpu_ccl.cpp index 11cb343..85404c4 100644 --- a/src/cpu/cpu_ccl.cpp +++ b/src/cpu/cpu_ccl.cpp @@ -157,6 +157,11 @@ class VanillaCPU final: public DispatchStub { std::vector& inputTensors, const GatherOptions& opts, ProcessGroupCCL& pg) override; + + c10::intrusive_ptr _reduce_scatter_base_(at::Tensor& outputTensor, + at::Tensor& inputTensor, + const ReduceScatterOptions& opts, + ProcessGroupCCL& pg) override; c10::intrusive_ptr alltoall_base_(at::Tensor& outputTensor, at::Tensor& inputTensor, @@ -532,7 +537,6 @@ c10::intrusive_ptr VanillaCPU::_allgather_base_(a return ret_evt; }, c10d::OpType::_ALLGATHER_BASE); - work->debugName = std::string("cpu::_allgather_base"); enqueue(work); return work; @@ -630,6 +634,57 @@ c10::intrusive_ptr VanillaCPU::gather_(std::vecto return work; } +c10::intrusive_ptr VanillaCPU::_reduce_scatter_base_(at::Tensor& outputTensor, + at::Tensor& inputTensor, + const ReduceScatterOptions& opts, + ProcessGroupCCL& pg) { + + checkSingleTensorHelper(inputTensor); + checkSingleTensorHelper(outputTensor); + int size = pg.getSize(); + if (inputTensor.dtype() != outputTensor.dtype()) { + TORCH_CHECK(false, "output tensor must have the same type as input tensor"); + } + + if (outputTensor.numel() * size != inputTensor.numel()) { + TORCH_CHECK( + false, + "input tensor size must be equal to world_size times output tensor size"); + } + std::vector inputs{inputTensor}; + std::vector outputs{outputTensor}; + + c10::intrusive_ptr work; + work = collective( + pg, + inputs, + outputs, + [=](at::Tensor input, + at::Tensor output, + ccl::reduce_scatter_attr attr, + ccl::communicator& comm) { + ccl::event ret_evt; + + call_with_lock(c10d::ProcessGroupCCL::globalMutex, [&](){ + CCL_CHECK(ret_evt = ccl::reduce_scatter(input.data_ptr(), + output.data_ptr(), + size_t(input.numel()/size), + cclDatatypes.at(input.scalar_type()), + cclOps.at(opts.reduceOp), + comm, + attr);); + }); + + return ret_evt; + }, + c10d::OpType::_REDUCE_SCATTER_BASE, + "oneccl_bindings_for_pytorch::cpu_work::_reduce_scatter_base"); + + work->debugName = std::string("cpu::_reduce_scatter_base"); + enqueue(work); + return work; +} + c10::intrusive_ptr VanillaCPU::alltoall_base_(at::Tensor& outputTensor, at::Tensor& inputTensor, std::vector& outputSplitSizes, diff --git a/src/dispatch_stub.cpp b/src/dispatch_stub.cpp index f472ae2..ba36952 100644 --- a/src/dispatch_stub.cpp +++ b/src/dispatch_stub.cpp @@ -258,6 +258,14 @@ class DebugCCLStub final: public DispatchStub { std::cout << os.str() << std::endl; return work; } + + c10::intrusive_ptr _reduce_scatter_base_(at::Tensor& outputTensor, + at::Tensor& inputTensor, + const ReduceScatterOptions& opts, + ProcessGroupCCL& pg_ccl) { + c10::DeviceType dev_type = 
inputTensor.device().type(); + return get_ccl_stub(dev_type)->_reduce_scatter_base_(outputTensor, inputTensor, opts, pg_ccl); + } c10::intrusive_ptr alltoall_base_(at::Tensor& outputTensor, at::Tensor& inputTensor, @@ -439,6 +447,14 @@ c10::intrusive_ptr DispatchStub::scatter(std::vec return get_ccl_stub(dev_type)->scatter_(outputTensors, inputTensors, opts, pg_ccl); } +c10::intrusive_ptr DispatchStub::_reduce_scatter_base(at::Tensor& outputTensor, + at::Tensor& inputTensor, + const ReduceScatterOptions& opts, + ProcessGroupCCL& pg_ccl) { + c10::DeviceType dev_type = inputTensor.device().type(); + return get_ccl_stub(dev_type)->_reduce_scatter_base_(outputTensor, inputTensor, opts, pg_ccl); +} + c10::intrusive_ptr DispatchStub::alltoall_base(at::Tensor& outputTensor, at::Tensor& inputTensor, std::vector& outputSplitSizes, diff --git a/src/dispatch_stub.h b/src/dispatch_stub.h index d01cff6..dad4584 100644 --- a/src/dispatch_stub.h +++ b/src/dispatch_stub.h @@ -85,6 +85,11 @@ class DispatchStub { const ScatterOptions& opts, ProcessGroupCCL& pg_ccl); + static c10::intrusive_ptr _reduce_scatter_base(at::Tensor & outputTensor, + at::Tensor& inputTensor, + const ReduceScatterOptions& opts, + ProcessGroupCCL& pg_ccl); + static c10::intrusive_ptr alltoall_base(at::Tensor& outputTensor, at::Tensor& inputTensor, std::vector& outputSplitSizes, @@ -159,6 +164,15 @@ class DispatchStub { fail(outputTensors[0].device().type(), "scatter"); return c10::intrusive_ptr(); } + + virtual c10::intrusive_ptr _reduce_scatter_base_(at::Tensor& outputTensor, + at::Tensor& inputTensor, + const ReduceScatterOptions& opts, + ProcessGroupCCL& pg_ccl) { + + fail(inputTensor.device().type(), "_reduce_scatter_base"); + return c10::intrusive_ptr(); + } virtual c10::intrusive_ptr broadcast_(std::vector& tensors, const BroadcastOptions& opts, diff --git a/src/gpu/dpcpp_ccl.cpp b/src/gpu/dpcpp_ccl.cpp index c87c27f..1240ae9 100644 --- a/src/gpu/dpcpp_ccl.cpp +++ b/src/gpu/dpcpp_ccl.cpp @@ -301,6 +301,16 @@ class XPUCCLStubs final: public DispatchStub { std::vector& inputTensors, const GatherOptions& opts, ProcessGroupCCL& pg) override; + + c10::intrusive_ptr _allgather_base_(at::Tensor& outputTensor, + at::Tensor& inputTensor, + const AllgatherOptions& opts, + ProcessGroupCCL& pg) override; + + c10::intrusive_ptr _reduce_scatter_base_(at::Tensor& outputTensor, + at::Tensor& inputTensor, + const ReduceScatterOptions& opts, + ProcessGroupCCL& pg) override; c10::intrusive_ptr alltoall_(std::vector& outputTensors, std::vector& inputTensors, @@ -763,6 +773,52 @@ c10::intrusive_ptr XPUCCLStubs::gather_(std::vect return work; } +c10::intrusive_ptr XPUCCLStubs::_allgather_base_(at::Tensor& outputTensor, + at::Tensor& inputTensor, + const ReduceScatterOptions& opts, + ProcessGroupCCL& pg) { + + checkSingleTensorHelper(inputTensor); + checkSingleTensorHelper(outputTensor); + int size = pg.getSize(); + if (inputTensor.dtype() != outputTensor.dtype()) { + TORCH_CHECK(false, "output tensor must have the same type as input tensor"); + } + + c10::intrusive_ptr work; + TORCH_CHECK( + false, + "_allgather_base_ is not supported in ProcessGroupCCL for GPU now"); + return work; +} + +c10::intrusive_ptr XPUCCLStubs::_reduce_scatter_base_(at::Tensor& outputTensor, + at::Tensor& inputTensor, + const ReduceScatterOptions& opts, + ProcessGroupCCL& pg) { + + checkSingleTensorHelper(inputTensor); + checkSingleTensorHelper(outputTensor); + int size = pg.getSize(); + if (inputTensor.dtype() != outputTensor.dtype()) { + TORCH_CHECK(false, "output tensor 
must have the same type as input tensor"); + } + + if (outputTensor.numel() * size != inputTensor.numel()) { + TORCH_CHECK( + false, + "input tensor size must be equal to world_size times output tensor size"); + } + std::vector inputs{inputTensor}; + std::vector outputs{outputTensor}; + + c10::intrusive_ptr work; + TORCH_CHECK( + false, + "_reduce_scatter_base_ is not supported in ProcessGroupCCL for GPU now"); + return work; +} + c10::intrusive_ptr XPUCCLStubs::alltoall_base_(at::Tensor& outputTensor, at::Tensor& inputTensor, std::vector& outputSplitSizes, diff --git a/tests/test_c10d_ccl.py b/tests/test_c10d_ccl.py index 8dba783..824cea6 100644 --- a/tests/test_c10d_ccl.py +++ b/tests/test_c10d_ccl.py @@ -264,6 +264,30 @@ def test_gather_basics_xpu(self): @skip_if_not_multixpu def test_gather_basics_multi_xpu(self): self._test_gather_basics(lambda t: t.clone().xpu("xpu:{}".format(self.rank))) + + def test_allgather_base_ops(self): + store = c10d.FileStore(self.file_name, self.world_size) + pg = c10d.ProcessGroupCCL(store, self.rank, self.world_size) + + def allgather_base(output_t, input_t): + work = pg._allgather_base(output_t, input_t) + work.wait() + + tensor = torch.tensor([self.rank]) + output_t = torch.empty((self.world_size), dtype=tensor.dtype) + allgather_base(output_t, tensor) + + def test_reduce_scatter_base_ops(self): + store = c10d.FileStore(self.file_name, self.world_size) + pg = c10d.ProcessGroupCCL(store, self.rank, self.world_size) + + def reduce_scatter_base(output_t, input_t): + work = pg._reduce_scatter_base(output_t, input_t) + work.wait() + + tensor = torch.arange(self.world_size) + output_t = torch.tensor(self.rank, dtype=tensor.dtype) + reduce_scatter_base(output_t, tensor) def _test_allgather_basics(self, fn): store = c10d.FileStore(self.file_name, self.world_size) From 03fe4a15eab3f12740a4583807c37bc89c1ba8cf Mon Sep 17 00:00:00 2001 From: liangan1 Date: Wed, 18 Jan 2023 15:43:36 +0800 Subject: [PATCH 6/9] Fix merge confict issue (#63) --- src/ProcessGroupCCL.hpp | 5 ----- src/cpu/cpu_ccl.cpp | 50 ----------------------------------------- src/dispatch_stub.cpp | 17 -------------- src/dispatch_stub.h | 14 ------------ tests/test_c10d_ccl.py | 21 +---------------- 5 files changed, 1 insertion(+), 106 deletions(-) diff --git a/src/ProcessGroupCCL.hpp b/src/ProcessGroupCCL.hpp index 2bbda33..9155d9a 100644 --- a/src/ProcessGroupCCL.hpp +++ b/src/ProcessGroupCCL.hpp @@ -175,11 +175,6 @@ class ProcessGroupCCL : public ProcessGroup std::vector>& inputTensors, const ReduceScatterOptions& opts = ReduceScatterOptions()) override; - c10::intrusive_ptr _reduce_scatter_base( - at::Tensor& outputTensor, - at::Tensor& inputTensor, - const ReduceScatterOptions& opts = ReduceScatterOptions()) override; - c10::intrusive_ptr _reduce_scatter_base( at::Tensor& outputBuffer, at::Tensor& inputBuffer, diff --git a/src/cpu/cpu_ccl.cpp b/src/cpu/cpu_ccl.cpp index 85404c4..eafae57 100644 --- a/src/cpu/cpu_ccl.cpp +++ b/src/cpu/cpu_ccl.cpp @@ -158,11 +158,6 @@ class VanillaCPU final: public DispatchStub { const GatherOptions& opts, ProcessGroupCCL& pg) override; - c10::intrusive_ptr _reduce_scatter_base_(at::Tensor& outputTensor, - at::Tensor& inputTensor, - const ReduceScatterOptions& opts, - ProcessGroupCCL& pg) override; - c10::intrusive_ptr alltoall_base_(at::Tensor& outputTensor, at::Tensor& inputTensor, std::vector& outputSplitSizes, @@ -352,51 +347,6 @@ c10::intrusive_ptr VanillaCPU::reduce_(std::vecto return work; } -c10::intrusive_ptr VanillaCPU::_reduce_scatter_base_(at::Tensor& 
outputTensor, - at::Tensor& inputTensor, - const ReduceScatterOptions& opts, - ProcessGroupCCL& pg_ccl) { - const int world_size = pg_ccl.getSize(); - if (inputTensor.numel() != outputTensor.numel() * world_size) { - TORCH_CHECK( - false, - "input tensor must be the same size as output size times world size"); - } - - // just a wrapper to fit the collective interface - auto inputs = std::vector {inputTensor}; - auto outputs = std::vector {outputTensor}; - - c10::intrusive_ptr work; - work = collective( - pg_ccl, - inputs, - outputs, - [=](at::Tensor input, - at::Tensor output, - ccl::reduce_attr attr, - ccl::communicator& comm) { - - ccl::event ret_evt; - call_with_lock(c10d::ProcessGroupCCL::globalMutex, [&]() { - CCL_CHECK(ret_evt = ccl::reduce_scatter(input.data_ptr(), - output.data_ptr(), - (size_t) output.numel(), - cclDatatypes.at(input.scalar_type()), - cclOps.at(opts.reduceOp), - comm)); - }); - return ret_evt; - - }, - c10d::OpType::_REDUCE_SCATTER_BASE, - "oneccl_bindings_for_pytorch::cpu_work::_reduce_scatter_base"); - - work->debugName = std::string("cpu::_reduce_scatter_base"); - enqueue(work); - return work; -} - c10::intrusive_ptr VanillaCPU::broadcast_(std::vector& tensors, const BroadcastOptions &opts, ProcessGroupCCL& pg) { diff --git a/src/dispatch_stub.cpp b/src/dispatch_stub.cpp index ba36952..533c523 100644 --- a/src/dispatch_stub.cpp +++ b/src/dispatch_stub.cpp @@ -259,14 +259,6 @@ class DebugCCLStub final: public DispatchStub { return work; } - c10::intrusive_ptr _reduce_scatter_base_(at::Tensor& outputTensor, - at::Tensor& inputTensor, - const ReduceScatterOptions& opts, - ProcessGroupCCL& pg_ccl) { - c10::DeviceType dev_type = inputTensor.device().type(); - return get_ccl_stub(dev_type)->_reduce_scatter_base_(outputTensor, inputTensor, opts, pg_ccl); - } - c10::intrusive_ptr alltoall_base_(at::Tensor& outputTensor, at::Tensor& inputTensor, std::vector& outputSplitSizes, @@ -390,15 +382,6 @@ c10::intrusive_ptr DispatchStub::reduce(std::vect } -c10::intrusive_ptr DispatchStub::_reduce_scatter_base(at::Tensor& outputTensor, - at::Tensor& inputTensor, - const ReduceScatterOptions& opts, - ProcessGroupCCL& pg_ccl) { - checkSameType(outputTensor, {outputTensor, inputTensor}); - c10::DeviceType dev_type = outputTensor.device().type(); - return get_ccl_stub(dev_type)->_reduce_scatter_base_(outputTensor, inputTensor, opts, pg_ccl); -} - c10::intrusive_ptr DispatchStub::broadcast(std::vector& tensors, const BroadcastOptions& opts, ProcessGroupCCL& pg_ccl) { diff --git a/src/dispatch_stub.h b/src/dispatch_stub.h index dad4584..7156b60 100644 --- a/src/dispatch_stub.h +++ b/src/dispatch_stub.h @@ -55,11 +55,6 @@ class DispatchStub { const ReduceOptions& opts, ProcessGroupCCL& pg_ccl); - static c10::intrusive_ptr _reduce_scatter_base(at::Tensor& outputTensor, - at::Tensor& inputTensor, - const ReduceScatterOptions& opts, - ProcessGroupCCL& pg_ccl); - static c10::intrusive_ptr broadcast(std::vector& tensors, const BroadcastOptions& opts, ProcessGroupCCL& pg_ccl); @@ -122,15 +117,6 @@ class DispatchStub { return c10::intrusive_ptr(); } - virtual c10::intrusive_ptr _reduce_scatter_base_(at::Tensor& outputTensor, - at::Tensor& inputTensor, - const ReduceScatterOptions& opts, - ProcessGroupCCL& pg_ccl) { - fail(outputTensor.device().type(), "_reduce_scatter_base"); - return c10::intrusive_ptr(); - } - - virtual c10::intrusive_ptr allgather_(std::vector>& outputTensors, std::vector& inputTensors, const AllgatherOptions& opts, diff --git a/tests/test_c10d_ccl.py 
b/tests/test_c10d_ccl.py index 824cea6..3efbf7a 100644 --- a/tests/test_c10d_ccl.py +++ b/tests/test_c10d_ccl.py @@ -482,26 +482,7 @@ def _test_reduce_scatter_base_basics(self, fn): def reduce_scatter_base(output_t, input_t): work = pg._reduce_scatter_base(output_t, input_t) work.wait() - - # anticpate an error - with self.assertRaisesRegex( - RuntimeError, - "input tensor must be the same size as output size times world size", - ): - input_t = fn(torch.tensor([self.rank])) - output_t = fn(torch.empty((self.world_size + 1), dtype=input_t.dtype)) - # fails the check because output_t is not correctly sized - reduce_scatter_base(output_t, input_t) - - # anticpate an error - with self.assertRaisesRegex( - RuntimeError, "Tensors are not equal in data type" - ): - tensor = fn(torch.tensor([self.rank], dtype=torch.float)) - output_t = fn(torch.empty((self.world_size + 1), dtype=torch.long)) - # fails the check because the dtype is different - reduce_scatter_base(output_t, tensor) - + def test_reduce_scatter_base_basics(self): self._test_reduce_scatter_base_basics(lambda t: t.clone()) From 686c83818a2ecc5e12b9e54247aa86f2633328bc Mon Sep 17 00:00:00 2001 From: zhuhong61 <95205772+zhuhong61@users.noreply.github.com> Date: Tue, 7 Feb 2023 14:43:58 +0800 Subject: [PATCH 7/9] Cancel version restriction of setuptools to avoid secuirty vulnerability issue (#66) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 083abc6..d3dc24c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,2 @@ torch>=1.10.0 -setuptools~=51.0.0 \ No newline at end of file +setuptools From b26412239768dd092e520d5befe7de2f43c5ab28 Mon Sep 17 00:00:00 2001 From: zhuhong61 <95205772+zhuhong61@users.noreply.github.com> Date: Tue, 4 Apr 2023 10:10:19 +0800 Subject: [PATCH 8/9] Correct the typos in README.md --- README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 6344420..008e9b3 100644 --- a/README.md +++ b/README.md @@ -68,11 +68,11 @@ The following build options are supported in Intel® oneCCL Bindings for PyTorch | CCL_SHA_VERSION | False | Add git head sha version to Wheel name | | BUILD_NO_ONECCL_PACKAGE | False | Package the Wheel without oneCCL library | -## Lunch Option List +## Launch Option List -The following lunch options are supported in Intel® oneCCL Bindings for PyTorch*. +The following launch options are supported in Intel® oneCCL Bindings for PyTorch*. 
-| Lunch Option | Default Value | Description | +| Launch Option | Default Value | Description | | :--------------------------------------: | :-----------: | :-------------------------------------------------------------------: | | ONECCL_BINDINGS_FOR_PYTORCH_ENV_VERBOSE | 0 | Set verbose level in ONECCL_BINDINGS_FOR_PYTORCH | | ONECCL_BINDINGS_FOR_PYTORCH_ENV_WAIT_GDB | 0 | Set 1 to force the oneccl_bindings_for_pytorch wait for GDB attaching | From 77f3c2a2c26bbb9ae43bc3c4a53bb0ac9742afdc Mon Sep 17 00:00:00 2001 From: Areg Melik-Adamyan Date: Mon, 22 May 2023 19:59:57 -0500 Subject: [PATCH 9/9] Update README.md Fixed typo --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 008e9b3..f94cf58 100644 --- a/README.md +++ b/README.md @@ -62,7 +62,7 @@ The following build options are supported in Intel® oneCCL Bindings for PyTorch | Build Option | Default Value | Description | | :---------------------------------: | :------------: | :-------------------------------------------------------------------------------------------------: | -| COMPUTE_BACKEND | | Set oneCCL `COMPUTE_BACKEDN`,set to `dpcpp` and use DPC++ Compiler to enable support for Intel XPU | +| COMPUTE_BACKEND | | Set oneCCL `COMPUTE_BACKEND`,set to `dpcpp` and use DPC++ Compiler to enable support for Intel XPU | | CCL_PACKAGE_NAME | oneccl-bind-pt | Set Wheel Name | | ONECCL_BINDINGS_FOR_PYTORCH_BACKEND | cpu | Set BACKEND | | CCL_SHA_VERSION | False | Add git head sha version to Wheel name |