diff --git a/src/common/device_vector.cuh b/src/common/device_vector.cuh index b8ffda0d10fe..865a153bbe23 100644 --- a/src/common/device_vector.cuh +++ b/src/common/device_vector.cuh @@ -7,10 +7,16 @@ #include // for device_vector #if defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1 -#include // for async_resource_ref -#include // for stream_ref +#include // for async_resource_ref +#include // for stream_ref + +// TODO(hcho3): Remove this guard once we require Rapids 25.12+ +#if (RMM_MAJOR_VERSION == 25 && RMM_MINOR_VERSION == 12) || RMM_MAJOR_VERSION >= 26 +#include // for get_current_device_resource +#else // (RMM_MAJOR_VERSION == 25 && RMM_MINOR_VERSION == 12) || RMM_MAJOR_VERSION >= 26 #include // for device_memory_resource #include // for get_current_device_resource +#endif // (RMM_MAJOR_VERSION == 25 && RMM_MINOR_VERSION == 12) || RMM_MAJOR_VERSION >= 26 #include "xgboost/global_config.h" // for GlobalConfigThreadLocalStore @@ -266,7 +272,13 @@ namespace detail { */ template class ThrustAllocMrAdapter : public thrust::device_malloc_allocator { + +// TODO(hcho3): Remove this guard once we require Rapids 25.12+ +#if (RMM_MAJOR_VERSION == 25 && RMM_MINOR_VERSION == 12) || RMM_MAJOR_VERSION >= 26 + DeviceAsyncResourceRef mr_{rmm::mr::get_current_device_resource_ref()}; +#else // (RMM_MAJOR_VERSION == 25 && RMM_MINOR_VERSION == 12) || RMM_MAJOR_VERSION >= 26 DeviceAsyncResourceRef mr_{rmm::mr::get_current_device_resource()}; +#endif // (RMM_MAJOR_VERSION == 25 && RMM_MINOR_VERSION == 12) || RMM_MAJOR_VERSION >= 26 public: using Super = thrust::device_malloc_allocator; diff --git a/tests/cpp/helpers.cc b/tests/cpp/helpers.cc index cd79ca66c589..5d8ca2b58aea 100644 --- a/tests/cpp/helpers.cc +++ b/tests/cpp/helpers.cc @@ -29,9 +29,18 @@ #if defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1 #include #include + +// TODO(hcho3): Remove this guard once we require Rapids 25.12+ +#if (RMM_MAJOR_VERSION == 25 && RMM_MINOR_VERSION == 12) || RMM_MAJOR_VERSION >= 26 +#include "rmm/mr/per_device_resource.hpp" +#include "rmm/mr/cuda_memory_resource.hpp" +#include "rmm/mr/pool_memory_resource.hpp" +#else // (RMM_MAJOR_VERSION == 25 && RMM_MINOR_VERSION == 12) || RMM_MAJOR_VERSION >= 26 #include "rmm/mr/device/per_device_resource.hpp" #include "rmm/mr/device/cuda_memory_resource.hpp" #include "rmm/mr/device/pool_memory_resource.hpp" +#endif // (RMM_MAJOR_VERSION == 25 && RMM_MINOR_VERSION == 12) || RMM_MAJOR_VERSION >= 26 + #endif // defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1 bool FileExists(const std::string& filename) {