Skip to content

Commit 1797581

Browse files
committed
Add macros to support building with RMM 25.10
1 parent 1beb1ec commit 1797581

File tree

2 files changed

+22
-0
lines changed

2 files changed

+22
-0
lines changed

src/common/device_vector.cuh

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,14 @@
99
#if defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
1010
#include <cuda/memory_resource> // for async_resource_ref
1111
#include <cuda/stream_ref> // for stream_ref
12+
13+
// TODO(hcho3): Remove this guard once we require Rapids 25.12+
14+
#if (RMM_MAJOR_VERSION == 25 && RMM_MINOR_VERSION == 12) || RMM_MAJOR_VERSION >= 26
1215
#include <rmm/mr/per_device_resource.hpp> // for get_current_device_resource
16+
#else // (RMM_MAJOR_VERSION == 25 && RMM_MINOR_VERSION == 12) || RMM_MAJOR_VERSION >= 26
17+
#include <rmm/mr/device/device_memory_resource.hpp> // for device_memory_resource
18+
#include <rmm/mr/device/per_device_resource.hpp> // for get_current_device_resource
19+
#endif // (RMM_MAJOR_VERSION == 25 && RMM_MINOR_VERSION == 12) || RMM_MAJOR_VERSION >= 26
1320

1421
#include "xgboost/global_config.h" // for GlobalConfigThreadLocalStore
1522

@@ -265,7 +272,13 @@ namespace detail {
265272
*/
266273
template <typename T>
267274
class ThrustAllocMrAdapter : public thrust::device_malloc_allocator<T> {
275+
276+
// TODO(hcho3): Remove this guard once we require Rapids 25.12+
277+
#if (RMM_MAJOR_VERSION == 25 && RMM_MINOR_VERSION == 12) || RMM_MAJOR_VERSION >= 26
268278
DeviceAsyncResourceRef mr_{rmm::mr::get_current_device_resource_ref()};
279+
#else // (RMM_MAJOR_VERSION == 25 && RMM_MINOR_VERSION == 12) || RMM_MAJOR_VERSION >= 26
280+
DeviceAsyncResourceRef mr_{rmm::mr::get_current_device_resource()};
281+
#endif // (RMM_MAJOR_VERSION == 25 && RMM_MINOR_VERSION == 12) || RMM_MAJOR_VERSION >= 26
269282

270283
public:
271284
using Super = thrust::device_malloc_allocator<T>;

tests/cpp/helpers.cc

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,18 @@
2929
#if defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
3030
#include <memory>
3131
#include <vector>
32+
33+
// TODO(hcho3): Remove this guard once we require Rapids 25.12+
34+
#if (RMM_MAJOR_VERSION == 25 && RMM_MINOR_VERSION == 12) || RMM_MAJOR_VERSION >= 26
3235
#include "rmm/mr/per_device_resource.hpp"
3336
#include "rmm/mr/cuda_memory_resource.hpp"
3437
#include "rmm/mr/pool_memory_resource.hpp"
38+
#else // (RMM_MAJOR_VERSION == 25 && RMM_MINOR_VERSION == 12) || RMM_MAJOR_VERSION >= 26
39+
#include "rmm/mr/device/per_device_resource.hpp"
40+
#include "rmm/mr/device/cuda_memory_resource.hpp"
41+
#include "rmm/mr/device/pool_memory_resource.hpp"
42+
#endif // (RMM_MAJOR_VERSION == 25 && RMM_MINOR_VERSION == 12) || RMM_MAJOR_VERSION >= 26
43+
3544
#endif // defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
3645

3746
bool FileExists(const std::string& filename) {

0 commit comments

Comments
 (0)