|
7 | 7 | #include <thrust/device_vector.h> // for device_vector |
8 | 8 |
|
9 | 9 | #if defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1 |
10 | | -#include <cuda/memory_resource> // for async_resource_ref |
11 | | -#include <cuda/stream_ref> // for stream_ref |
| 10 | +#include <cuda/memory_resource> // for async_resource_ref |
| 11 | +#include <cuda/stream_ref> // for stream_ref |
| 12 | + |
| 13 | +// TODO(hcho3): Remove this guard once we require Rapids 25.12+ |
| 14 | +#if (RMM_MAJOR_VERSION == 25 && RMM_MINOR_VERSION == 12) || RMM_MAJOR_VERSION >= 26 |
| 15 | +#include <rmm/mr/per_device_resource.hpp> // for get_current_device_resource |
| 16 | +#else // (RMM_MAJOR_VERSION == 25 && RMM_MINOR_VERSION == 12) || RMM_MAJOR_VERSION >= 26 |
12 | 17 | #include <rmm/mr/device/device_memory_resource.hpp> // for device_memory_resource |
13 | 18 | #include <rmm/mr/device/per_device_resource.hpp> // for get_current_device_resource |
| 19 | +#endif // (RMM_MAJOR_VERSION == 25 && RMM_MINOR_VERSION == 12) || RMM_MAJOR_VERSION >= 26 |
14 | 20 |
|
15 | 21 | #include "xgboost/global_config.h" // for GlobalConfigThreadLocalStore |
16 | 22 |
|
@@ -266,7 +272,13 @@ namespace detail { |
266 | 272 | */ |
267 | 273 | template <typename T> |
268 | 274 | class ThrustAllocMrAdapter : public thrust::device_malloc_allocator<T> { |
| 275 | + |
| 276 | +// TODO(hcho3): Remove this guard once we require Rapids 25.12+ |
| 277 | +#if (RMM_MAJOR_VERSION == 25 && RMM_MINOR_VERSION == 12) || RMM_MAJOR_VERSION >= 26 |
| 278 | + DeviceAsyncResourceRef mr_{rmm::mr::get_current_device_resource_ref()}; |
| 279 | +#else // (RMM_MAJOR_VERSION == 25 && RMM_MINOR_VERSION == 12) || RMM_MAJOR_VERSION >= 26 |
269 | 280 | DeviceAsyncResourceRef mr_{rmm::mr::get_current_device_resource()}; |
| 281 | +#endif // (RMM_MAJOR_VERSION == 25 && RMM_MINOR_VERSION == 12) || RMM_MAJOR_VERSION >= 26 |
270 | 282 |
|
271 | 283 | public: |
272 | 284 | using Super = thrust::device_malloc_allocator<T>; |
|
0 commit comments