From 999879e0e670c2dc817986b1c922dd57e3bbc40b Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Thu, 2 Jan 2025 15:31:18 -0800 Subject: [PATCH 1/5] Do not create multiple cublas handle Signed-off-by: Przemek Tredak --- .../common/gemm/cublaslt_gemm.cu | 24 +++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index ef7cdc0af9..2a03913b92 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -50,6 +50,27 @@ uint32_t _getAlignment(uintptr_t address) { namespace transformer_engine { +class cublasHandleManager { + public: + static cublasHandleManager &Instance() { + static thread_local cublasHandleManager instance; + return instance; + } + + cublasLtHandle_t GetHandle() { + static thread_local std::once_flag flag; + std::call_once(flag, [&] { + NVTE_CHECK_CUBLAS(cublasLtCreate(&handle_)); + }); + return handle_; + } + + ~cublasHandleManager() {} + + private: + cublasLtHandle_t handle_ = nullptr; +}; + void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, const Tensor *inputBias, Tensor *outputPreGelu, int m, int n, int k, int lda, int ldb, int ldd, cublasOperation_t transa, cublasOperation_t transb, bool grad, @@ -98,8 +119,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, float zero = 0.0; float beta = (accumulate) ? one : zero; - cublasLtHandle_t handle; - NVTE_CHECK_CUBLAS(cublasLtCreate(&handle)); + cublasLtHandle_t handle = cublasHandleManager::Instance().GetHandle(); cublasLtMatmulDesc_t operationDesc = nullptr; cublasLtMatrixLayout_t Adesc = nullptr, Bdesc = nullptr, Cdesc = nullptr, Ddesc = nullptr; From 7b868b0c30cf9da417cf8b273f9ac5d6c5aa46a0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 2 Jan 2025 23:38:04 +0000 Subject: [PATCH 2/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/common/gemm/cublaslt_gemm.cu | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 2a03913b92..009405729d 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -59,9 +59,7 @@ class cublasHandleManager { cublasLtHandle_t GetHandle() { static thread_local std::once_flag flag; - std::call_once(flag, [&] { - NVTE_CHECK_CUBLAS(cublasLtCreate(&handle_)); - }); + std::call_once(flag, [&] { NVTE_CHECK_CUBLAS(cublasLtCreate(&handle_)); }); return handle_; } From 529aec2ae5869591040a7d53288c8929bf754af9 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Mon, 6 Jan 2025 14:50:28 -0800 Subject: [PATCH 3/5] Fix for multiple GPUs per thread Signed-off-by: Przemek Tredak --- transformer_engine/common/cudnn_utils.cpp | 10 +++- transformer_engine/common/cudnn_utils.h | 30 +++------- .../common/fused_attn/fused_attn.cpp | 12 ++-- .../common/gemm/cublaslt_gemm.cu | 24 ++------ .../common/normalization/common.cpp | 2 +- .../common/util/handle_manager.h | 55 +++++++++++++++++++ 6 files changed, 86 insertions(+), 47 deletions(-) create mode 100644 transformer_engine/common/util/handle_manager.h diff --git a/transformer_engine/common/cudnn_utils.cpp b/transformer_engine/common/cudnn_utils.cpp index 80d2707315..f44edffe66 100644 --- a/transformer_engine/common/cudnn_utils.cpp +++ b/transformer_engine/common/cudnn_utils.cpp @@ -58,9 +58,17 @@ cudnn_frontend::DataType_t get_cudnn_fe_dtype(const transformer_engine::DType t) } void nvte_cudnn_handle_init() { - auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle(); + auto _ = cudnnExecutionPlanManager::Instance().GetHandle(); } +namespace detail { + +void CreateCuDNNHandle(cudnnHandle_t* handle) { + NVTE_CHECK_CUDNN(cudnnCreate(handle)); +} + +} // namespace detail + } // namespace transformer_engine namespace cudnn_frontend { diff --git a/transformer_engine/common/cudnn_utils.h b/transformer_engine/common/cudnn_utils.h index eb19b9ddb2..0016ad7f55 100644 --- a/transformer_engine/common/cudnn_utils.h +++ b/transformer_engine/common/cudnn_utils.h @@ -10,37 +10,25 @@ #include #include #include - -#include -#include +#include #include "transformer_engine/transformer_engine.h" +#include "util/handle_manager.h" namespace transformer_engine { -cudnnDataType_t get_cudnn_dtype(const transformer_engine::DType t); +namespace detail { -cudnn_frontend::DataType_t get_cudnn_fe_dtype(const transformer_engine::DType t); +void CreateCuDNNHandle(cudnnHandle_t* handle); -class cudnnExecutionPlanManager { - public: - static cudnnExecutionPlanManager &Instance() { - static thread_local cudnnExecutionPlanManager instance; - return instance; - } +} // namespace detail - cudnnHandle_t GetCudnnHandle() { - static thread_local std::once_flag flag; - std::call_once(flag, [&] { cudnnCreate(&handle_); }); - return handle_; - } +cudnnDataType_t get_cudnn_dtype(const transformer_engine::DType t); - ~cudnnExecutionPlanManager() {} +cudnn_frontend::DataType_t get_cudnn_fe_dtype(const transformer_engine::DType t); - private: - cudnnHandle_t handle_ = nullptr; -}; +using cudnnExecutionPlanManager = detail::HandleManager; } // namespace transformer_engine -#endif +#endif // TRANSFORMER_ENGINE_CUDNN_UTILS_H_ diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 5d3e1d6097..7f871c3d3a 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -304,7 +304,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, t = input_QKV->data.shape[0]; } - auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle(); + auto handle = cudnnExecutionPlanManager::Instance().GetHandle(); const NVTEDType QKV_type = static_cast(input_QKV->data.dtype); NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( @@ -386,7 +386,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con t = input_QKV->data.shape[0]; } - auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle(); + auto handle = cudnnExecutionPlanManager::Instance().GetHandle(); const NVTEDType QKV_type = static_cast(input_QKV->data.dtype); NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( @@ -486,7 +486,7 @@ void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const t_kv = input_KV->data.shape[0]; } - auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle(); + auto handle = cudnnExecutionPlanManager::Instance().GetHandle(); const NVTEDType Q_type = static_cast(input_Q->data.dtype); const NVTEDType KV_type = static_cast(input_KV->data.dtype); @@ -577,7 +577,7 @@ void nvte_fused_attn_bwd_kvpacked( t_kv = input_KV->data.shape[0]; } - auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle(); + auto handle = cudnnExecutionPlanManager::Instance().GetHandle(); const NVTEDType Q_type = static_cast(input_Q->data.dtype); const NVTEDType KV_type = static_cast(input_KV->data.dtype); @@ -674,7 +674,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso t_kv = input_K->data.shape[0]; } - auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle(); + auto handle = cudnnExecutionPlanManager::Instance().GetHandle(); const NVTEDType Q_type = static_cast(input_Q->data.dtype); const NVTEDType KV_type = static_cast(input_K->data.dtype); @@ -761,7 +761,7 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso t_kv = input_K->data.shape[0]; } - auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle(); + auto handle = cudnnExecutionPlanManager::Instance().GetHandle(); const NVTEDType Q_type = static_cast(input_Q->data.dtype); const NVTEDType KV_type = static_cast(input_K->data.dtype); diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 009405729d..bbd2b2ed50 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -15,6 +15,7 @@ #include "../common.h" #include "../util/logging.h" +#include "../util/handle_manager.h" namespace { @@ -46,28 +47,15 @@ uint32_t _getAlignment(uintptr_t address) { } } +inline void CreateCublasHandle(cublasLtHandle_t* handle) { + NVTE_CHECK_CUBLAS(cublasLtCreate(handle)); +} + } // namespace namespace transformer_engine { -class cublasHandleManager { - public: - static cublasHandleManager &Instance() { - static thread_local cublasHandleManager instance; - return instance; - } - - cublasLtHandle_t GetHandle() { - static thread_local std::once_flag flag; - std::call_once(flag, [&] { NVTE_CHECK_CUBLAS(cublasLtCreate(&handle_)); }); - return handle_; - } - - ~cublasHandleManager() {} - - private: - cublasLtHandle_t handle_ = nullptr; -}; +using cublasHandleManager = detail::HandleManager; void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, const Tensor *inputBias, Tensor *outputPreGelu, int m, int n, int k, int lda, diff --git a/transformer_engine/common/normalization/common.cpp b/transformer_engine/common/normalization/common.cpp index 89e2e9feec..c9b7e643c6 100644 --- a/transformer_engine/common/normalization/common.cpp +++ b/transformer_engine/common/normalization/common.cpp @@ -190,7 +190,7 @@ CudnnNormalizationPlan::CudnnNormalizationPlan(NVTE_Norm_Type NormType, NVTE_Nor TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( wtype, cpp_dtype, *(reinterpret_cast(_scalar_dptr.get())) = (cpp_dtype)1.0f;); - _handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle(); + _handle = cudnnExecutionPlanManager::Instance().GetHandle(); _graph.set_io_data_type(get_cudnn_fe_dtype(itype)) .set_intermediate_data_type(get_cudnn_fe_dtype(ctype)) diff --git a/transformer_engine/common/util/handle_manager.h b/transformer_engine/common/util/handle_manager.h new file mode 100644 index 0000000000..1bdf3e92d2 --- /dev/null +++ b/transformer_engine/common/util/handle_manager.h @@ -0,0 +1,55 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_UTIL_HANDLE_MANAGER_H_ +#define TRANSFORMER_ENGINE_UTIL_HANDLE_MANAGER_H_ + +#include +#include +#include +#include "logging.h" +#include "cuda_runtime.h" + +namespace transformer_engine::detail { + +template +class HandleManager { + public: + static HandleManager &Instance() { + static thread_local HandleManager instance; + return instance; + } + + Handle GetHandle() { + static std::vector flags(handles_.size()); + int device_id = cuda::current_device(); + NVTE_CHECK(0 <= device_id && device_id < handles_.size(), "invalid CUDA device ID"); + auto init = [&]() { + Create(&(handles_[device_id])); + }; + std::call_once(flags[device_id], init); + return handles_[device_id]; + } + + ~HandleManager() { + if (Destroy != nullptr) { + for (auto& handle : handles_) { + Destroy(handle); + } + } + } + + private: + HandleManager() : handles_(cuda::num_devices(), nullptr) {} + + std::vector handles_ = nullptr; +}; + +} // namespace transformer_engine::detail + +#endif // TRANSFORMER_ENGINE_UTIL_HANDLE_MANAGER_H_ From 44c7d70ae7cc80700bbb3b5b457949809d4ee551 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 6 Jan 2025 22:51:21 +0000 Subject: [PATCH 4/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/common/cudnn_utils.cpp | 10 +++------- transformer_engine/common/gemm/cublaslt_gemm.cu | 4 ++-- transformer_engine/common/util/handle_manager.h | 16 +++++++--------- 3 files changed, 12 insertions(+), 18 deletions(-) diff --git a/transformer_engine/common/cudnn_utils.cpp b/transformer_engine/common/cudnn_utils.cpp index f44edffe66..eaf6de680a 100644 --- a/transformer_engine/common/cudnn_utils.cpp +++ b/transformer_engine/common/cudnn_utils.cpp @@ -57,15 +57,11 @@ cudnn_frontend::DataType_t get_cudnn_fe_dtype(const transformer_engine::DType t) } } -void nvte_cudnn_handle_init() { - auto _ = cudnnExecutionPlanManager::Instance().GetHandle(); -} +void nvte_cudnn_handle_init() { auto _ = cudnnExecutionPlanManager::Instance().GetHandle(); } namespace detail { -void CreateCuDNNHandle(cudnnHandle_t* handle) { - NVTE_CHECK_CUDNN(cudnnCreate(handle)); -} +void CreateCuDNNHandle(cudnnHandle_t* handle) { NVTE_CHECK_CUDNN(cudnnCreate(handle)); } } // namespace detail @@ -76,6 +72,6 @@ namespace cudnn_frontend { // This is needed to define the symbol `cudnn_dlhandle` // When using the flag NV_CUDNN_FRONTEND_USE_DYNAMIC_LOADING // to enable dynamic loading. -void *cudnn_dlhandle = nullptr; +void* cudnn_dlhandle = nullptr; } // namespace cudnn_frontend diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index bbd2b2ed50..0a78ded306 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -14,8 +14,8 @@ #include #include "../common.h" -#include "../util/logging.h" #include "../util/handle_manager.h" +#include "../util/logging.h" namespace { @@ -47,7 +47,7 @@ uint32_t _getAlignment(uintptr_t address) { } } -inline void CreateCublasHandle(cublasLtHandle_t* handle) { +inline void CreateCublasHandle(cublasLtHandle_t *handle) { NVTE_CHECK_CUBLAS(cublasLtCreate(handle)); } diff --git a/transformer_engine/common/util/handle_manager.h b/transformer_engine/common/util/handle_manager.h index 1bdf3e92d2..27666c6948 100644 --- a/transformer_engine/common/util/handle_manager.h +++ b/transformer_engine/common/util/handle_manager.h @@ -7,20 +7,20 @@ #ifndef TRANSFORMER_ENGINE_UTIL_HANDLE_MANAGER_H_ #define TRANSFORMER_ENGINE_UTIL_HANDLE_MANAGER_H_ +#include + #include #include -#include -#include "logging.h" + #include "cuda_runtime.h" +#include "logging.h" namespace transformer_engine::detail { -template +template class HandleManager { public: - static HandleManager &Instance() { + static HandleManager& Instance() { static thread_local HandleManager instance; return instance; } @@ -29,9 +29,7 @@ class HandleManager { static std::vector flags(handles_.size()); int device_id = cuda::current_device(); NVTE_CHECK(0 <= device_id && device_id < handles_.size(), "invalid CUDA device ID"); - auto init = [&]() { - Create(&(handles_[device_id])); - }; + auto init = [&]() { Create(&(handles_[device_id])); }; std::call_once(flags[device_id], init); return handles_[device_id]; } From e62a1282bacd1963496a6da7d3bab0f72d903495 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Mon, 6 Jan 2025 15:55:19 -0800 Subject: [PATCH 5/5] Fix multithreaded execution Signed-off-by: Przemek Tredak --- transformer_engine/common/util/handle_manager.h | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/transformer_engine/common/util/handle_manager.h b/transformer_engine/common/util/handle_manager.h index 27666c6948..adb2f55587 100644 --- a/transformer_engine/common/util/handle_manager.h +++ b/transformer_engine/common/util/handle_manager.h @@ -7,9 +7,6 @@ #ifndef TRANSFORMER_ENGINE_UTIL_HANDLE_MANAGER_H_ #define TRANSFORMER_ENGINE_UTIL_HANDLE_MANAGER_H_ -#include - -#include #include #include "cuda_runtime.h" @@ -26,11 +23,13 @@ class HandleManager { } Handle GetHandle() { - static std::vector flags(handles_.size()); - int device_id = cuda::current_device(); + static thread_local std::vector initialized(handles_.size(), false); + const int device_id = cuda::current_device(); NVTE_CHECK(0 <= device_id && device_id < handles_.size(), "invalid CUDA device ID"); - auto init = [&]() { Create(&(handles_[device_id])); }; - std::call_once(flags[device_id], init); + if (!initialized[device_id]) { + Create(&(handles_[device_id])); + initialized[device_id] = true; + } return handles_[device_id]; }