1818#include " xgboost/string_view.h" // for StringView
1919
2020namespace xgboost ::cudr {
21- CuDriverApi::CuDriverApi () {
21+ CuDriverApi::CuDriverApi (std:: int32_t cu_major, std:: int32_t cu_minor, std:: int32_t kdm_major ) {
2222 // similar to dlopen, but without the need to release a handle.
2323 auto safe_load = [](xgboost::StringView name, auto **fnptr) {
2424 cudaDriverEntryPointQueryResult status;
@@ -41,7 +41,12 @@ CuDriverApi::CuDriverApi() {
4141 safe_load (" cuDeviceGetAttribute" , &this ->cuDeviceGetAttribute );
4242 safe_load (" cuDeviceGet" , &this ->cuDeviceGet );
4343#if defined(CUDA_HW_DECOM_AVAILABLE)
44- safe_load (" cuMemBatchDecompressAsync" , &this ->cuMemBatchDecompressAsync );
44+ // CTK 12.8
45+ if (((cu_major == 12 && cu_minor >= 8 ) || cu_major > 12 ) && (kdm_major >= 570 )) {
46+ safe_load (" cuMemBatchDecompressAsync" , &this ->cuMemBatchDecompressAsync );
47+ } else {
48+ this ->cuMemBatchDecompressAsync = nullptr ;
49+ }
4550#endif // defined(CUDA_HW_DECOM_AVAILABLE)
4651 CHECK (this ->cuMemGetAllocationGranularity );
4752}
@@ -76,9 +81,17 @@ void CuDriverApi::ThrowIfError(CUresult status, StringView fn, std::int32_t line
7681}
7782
7883[[nodiscard]] CuDriverApi &GetGlobalCuDriverApi () {
84+ std::int32_t cu_major = -1 , cu_minor = -1 ;
85+ GetDrVersionGlobal (&cu_major, &cu_minor);
86+
87+ std::int32_t kdm_major = -1 , kdm_minor = -1 ;
88+ if (!GetVersionFromSmiGlobal (&kdm_major, &kdm_minor)) {
89+ kdm_major = -1 ;
90+ }
91+
7992 static std::once_flag flag;
8093 static std::unique_ptr<CuDriverApi> cu;
81- std::call_once (flag, [&] { cu = std::make_unique<CuDriverApi>(); });
94+ std::call_once (flag, [&] { cu = std::make_unique<CuDriverApi>(cu_major, cu_minor, kdm_major ); });
8295 return *cu;
8396}
8497
@@ -154,5 +167,24 @@ void MakeCuMemLocation(CUmemLocationType type, CUmemLocation *loc) {
154167
155168 return Invalid ();
156169}
170+
171+ [[nodiscard]] bool GetVersionFromSmiGlobal (std::int32_t *p_major, std::int32_t *p_minor) {
172+ static std::once_flag flag;
173+ static std::int32_t major = -1 , minor = -1 ;
174+ static bool result = false ;
175+ std::call_once (flag, [&] { result = GetVersionFromSmi (&major, &minor); });
176+
177+ *p_major = major;
178+ *p_minor = minor;
179+ return result;
180+ }
181+
182+ void GetDrVersionGlobal (std::int32_t *p_major, std::int32_t *p_minor) {
183+ static std::once_flag once;
184+ static std::int32_t major{0 }, minor{0 };
185+ std::call_once (once, [] { xgboost::curt::DrVersion (&major, &minor); });
186+ *p_major = major;
187+ *p_minor = minor;
188+ }
157189} // namespace xgboost::cudr
158190#endif
0 commit comments