diff --git a/transformer_engine/common/cudnn_utils.cpp b/transformer_engine/common/cudnn_utils.cpp index 80d2707315..eaf6de680a 100644 --- a/transformer_engine/common/cudnn_utils.cpp +++ b/transformer_engine/common/cudnn_utils.cpp @@ -57,9 +57,13 @@ cudnn_frontend::DataType_t get_cudnn_fe_dtype(const transformer_engine::DType t) } } -void nvte_cudnn_handle_init() { - auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle(); -} +/* Requests a handle from cudnnExecutionPlanManager and discards the result; the call is made only for the side effect of GetHandle(). */ void nvte_cudnn_handle_init() { auto _ = cudnnExecutionPlanManager::Instance().GetHandle(); } + +namespace detail { + +/* Creates a cuDNN handle with cudnnCreate, writing it through the handle pointer; the returned status is passed to NVTE_CHECK_CUDNN. */ void CreateCuDNNHandle(cudnnHandle_t* handle) { NVTE_CHECK_CUDNN(cudnnCreate(handle)); } + +} // namespace detail } // namespace transformer_engine @@ -68,6 +72,6 @@ namespace cudnn_frontend { // This is needed to define the symbol `cudnn_dlhandle` // When using the flag NV_CUDNN_FRONTEND_USE_DYNAMIC_LOADING // to enable dynamic loading. -void *cudnn_dlhandle = nullptr; +void* cudnn_dlhandle = nullptr; } // namespace cudnn_frontend diff --git a/transformer_engine/common/cudnn_utils.h b/transformer_engine/common/cudnn_utils.h index eb19b9ddb2..0016ad7f55 100644 --- a/transformer_engine/common/cudnn_utils.h +++ b/transformer_engine/common/cudnn_utils.h @@ -10,37 +10,25 @@ #include #include #include - -#include -#include +#include #include "transformer_engine/transformer_engine.h" +#include "util/handle_manager.h" namespace transformer_engine { -cudnnDataType_t get_cudnn_dtype(const transformer_engine::DType t); +namespace detail { -cudnn_frontend::DataType_t get_cudnn_fe_dtype(const transformer_engine::DType t); +void CreateCuDNNHandle(cudnnHandle_t* handle); -class cudnnExecutionPlanManager { - public: - static cudnnExecutionPlanManager &Instance() { - static thread_local cudnnExecutionPlanManager instance; - return instance; - } +} // namespace detail - cudnnHandle_t GetCudnnHandle() { - static thread_local std::once_flag flag; - std::call_once(flag, [&] { cudnnCreate(&handle_); }); - return handle_; - } 
+cudnnDataType_t get_cudnn_dtype(const transformer_engine::DType t); - ~cudnnExecutionPlanManager() {} +cudnn_frontend::DataType_t get_cudnn_fe_dtype(const transformer_engine::DType t); - private: - cudnnHandle_t handle_ = nullptr; -}; +/* Keeps the cudnnExecutionPlanManager name as an alias of detail::HandleManager; call sites in this patch switch from GetCudnnHandle() to GetHandle(). NOTE(review): the template arguments of detail::HandleManager are not visible in this view - confirm they name cudnnHandle_t and detail::CreateCuDNNHandle. */ using cudnnExecutionPlanManager = detail::HandleManager; } // namespace transformer_engine -#endif +#endif // TRANSFORMER_ENGINE_CUDNN_UTILS_H_ diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 5d3e1d6097..7f871c3d3a 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -304,7 +304,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, t = input_QKV->data.shape[0]; } - auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle(); + auto handle = cudnnExecutionPlanManager::Instance().GetHandle(); const NVTEDType QKV_type = static_cast(input_QKV->data.dtype); NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( @@ -386,7 +386,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con t = input_QKV->data.shape[0]; } - auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle(); + auto handle = cudnnExecutionPlanManager::Instance().GetHandle(); const NVTEDType QKV_type = static_cast(input_QKV->data.dtype); NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( @@ -486,7 +486,7 @@ void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const t_kv = input_KV->data.shape[0]; } - auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle(); + auto handle = cudnnExecutionPlanManager::Instance().GetHandle(); const NVTEDType Q_type = static_cast(input_Q->data.dtype); const NVTEDType KV_type = static_cast(input_KV->data.dtype); @@ -577,7 +577,7 @@ void nvte_fused_attn_bwd_kvpacked( t_kv = input_KV->data.shape[0]; } - auto handle = 
cudnnExecutionPlanManager::Instance().GetCudnnHandle(); + auto handle = cudnnExecutionPlanManager::Instance().GetHandle(); const NVTEDType Q_type = static_cast(input_Q->data.dtype); const NVTEDType KV_type = static_cast(input_KV->data.dtype); @@ -674,7 +674,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso t_kv = input_K->data.shape[0]; } - auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle(); + auto handle = cudnnExecutionPlanManager::Instance().GetHandle(); const NVTEDType Q_type = static_cast(input_Q->data.dtype); const NVTEDType KV_type = static_cast(input_K->data.dtype); @@ -761,7 +761,7 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso t_kv = input_K->data.shape[0]; } - auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle(); + auto handle = cudnnExecutionPlanManager::Instance().GetHandle(); const NVTEDType Q_type = static_cast(input_Q->data.dtype); const NVTEDType KV_type = static_cast(input_K->data.dtype); diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index ef7cdc0af9..0a78ded306 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -14,6 +14,7 @@ #include #include "../common.h" +#include "../util/handle_manager.h" #include "../util/logging.h" namespace { @@ -46,10 +47,16 @@ uint32_t _getAlignment(uintptr_t address) { } } +/* Creates a cuBLASLt handle with cublasLtCreate, writing it through the handle pointer; the returned status is passed to NVTE_CHECK_CUBLAS. */ inline void CreateCublasHandle(cublasLtHandle_t *handle) { + NVTE_CHECK_CUBLAS(cublasLtCreate(handle)); +} + } // namespace namespace transformer_engine { +/* Handle cache used by cublas_gemm. NOTE(review): the template arguments of detail::HandleManager are not visible in this view - confirm they name cublasLtHandle_t and CreateCublasHandle. */ using cublasHandleManager = detail::HandleManager; + void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, const Tensor *inputBias, Tensor *outputPreGelu, int m, int n, int k, int lda, int ldb, int ldd, cublasOperation_t transa, cublasOperation_t transb, bool grad, @@ -98,8 +105,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor 
*inputB, Tensor *outputD, float zero = 0.0; float beta = (accumulate) ? one : zero; - cublasLtHandle_t handle; - NVTE_CHECK_CUBLAS(cublasLtCreate(&handle)); + /* Take the handle from cublasHandleManager instead of calling cublasLtCreate on every cublas_gemm call. */ cublasLtHandle_t handle = cublasHandleManager::Instance().GetHandle(); cublasLtMatmulDesc_t operationDesc = nullptr; cublasLtMatrixLayout_t Adesc = nullptr, Bdesc = nullptr, Cdesc = nullptr, Ddesc = nullptr; diff --git a/transformer_engine/common/normalization/common.cpp b/transformer_engine/common/normalization/common.cpp index 89e2e9feec..c9b7e643c6 100644 --- a/transformer_engine/common/normalization/common.cpp +++ b/transformer_engine/common/normalization/common.cpp @@ -190,7 +190,7 @@ CudnnNormalizationPlan::CudnnNormalizationPlan(NVTE_Norm_Type NormType, NVTE_Nor TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( wtype, cpp_dtype, *(reinterpret_cast(_scalar_dptr.get())) = (cpp_dtype)1.0f;); - _handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle(); + _handle = cudnnExecutionPlanManager::Instance().GetHandle(); _graph.set_io_data_type(get_cudnn_fe_dtype(itype)) .set_intermediate_data_type(get_cudnn_fe_dtype(ctype)) diff --git a/transformer_engine/common/util/handle_manager.h b/transformer_engine/common/util/handle_manager.h new file mode 100644 index 0000000000..adb2f55587 --- /dev/null +++ b/transformer_engine/common/util/handle_manager.h @@ -0,0 +1,52 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_UTIL_HANDLE_MANAGER_H_ +#define TRANSFORMER_ENGINE_UTIL_HANDLE_MANAGER_H_ + +#include + +#include "cuda_runtime.h" +#include "logging.h" + +namespace transformer_engine::detail { + +/* Caches one library handle per CUDA device. Instance() returns a thread_local object, so each thread holds its own set of handles. Create fills a slot the first time GetHandle() is called for a device; Destroy, when it is not nullptr, is invoked on every slot by the destructor. NOTE(review): the template parameter list is not visible in this view; Handle, Create and Destroy are presumably its parameters - confirm. */ template class HandleManager { + public: + /* Returns the thread_local instance, which is constructed on the first call made by each thread. */ static HandleManager& Instance() { + static thread_local HandleManager instance; + return instance; + } + + /* Returns the handle for the device reported by cuda::current_device(), calling Create for that device the first time it is requested. Fails through NVTE_CHECK when the device id is outside the range of handles_. */ Handle GetHandle() { + /* Per-device created flags; NOTE(review): this is a function-local thread_local sized from handles_.size() on first call, kept separately from the object, so the destructor cannot consult it. */ static thread_local std::vector initialized(handles_.size(), false); + const int device_id = cuda::current_device(); + /* NOTE(review): device_id is int while handles_.size() is unsigned, so the second test is a mixed-sign comparison; the 0 <= device_id test is evaluated first. */ NVTE_CHECK(0 <= device_id && device_id < handles_.size(), "invalid CUDA device ID"); + if (!initialized[device_id]) { + Create(&(handles_[device_id])); + initialized[device_id] = true; + } + return handles_[device_id]; + } + + /* Calls Destroy on every slot when a Destroy function was supplied. NOTE(review): slots that were never passed to Create still hold the nullptr set by the constructor and are passed to Destroy as well - confirm that Destroy accepts a null handle. */ ~HandleManager() { + if (Destroy != nullptr) { + for (auto& handle : handles_) { + Destroy(handle); + } + } + } + + private: + /* Private constructor, so instances come only from Instance(); sizes handles_ to cuda::num_devices() entries, each set to nullptr. */ HandleManager() : handles_(cuda::num_devices(), nullptr) {} + + /* One slot per device, indexed by device id. NOTE(review): the nullptr default initializer is superseded by the constructor init list above and looks out of place on a std::vector - confirm and consider removing. */ std::vector handles_ = nullptr; +}; + +} // namespace transformer_engine::detail + +#endif // TRANSFORMER_ENGINE_UTIL_HANDLE_MANAGER_H_