Skip to content

Commit ef312f0

Browse files
authored
Implement the device uvector. (#11715)
1 parent 470573a commit ef312f0

File tree

4 files changed

+99
-38
lines changed

4 files changed

+99
-38
lines changed

include/xgboost/span.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/**
2-
* Copyright 2018-2024, XGBoost contributors
2+
* Copyright 2018-2025, XGBoost contributors
33
* \brief span class based on ISO C++20 span
44
*
55
* About NOLINTs in this file:
@@ -358,6 +358,11 @@ XGBOOST_DEVICE bool LexicographicalCompare(InputIt1 first1, InputIt1 last1,
358358

359359
} // namespace detail
360360

361+
template <typename T>
362+
XGBOOST_DEVICE std::enable_if_t<!std::is_reference_v<T> && !std::is_pointer_v<T>, std::size_t>
363+
SizeBytes(std::size_t n) {
364+
return n * sizeof(T);
365+
}
361366

362367
/*!
363368
* \brief span class implementation, based on ISO C++20 span<T>. The interface
@@ -556,7 +561,7 @@ class Span {
556561
return size_;
557562
}
558563
XGBOOST_DEVICE constexpr index_type size_bytes() const __span_noexcept { // NOLINT
559-
return size() * sizeof(T);
564+
return SizeBytes<T>(size());
560565
}
561566

562567
XGBOOST_DEVICE constexpr bool empty() const __span_noexcept { // NOLINT

src/common/device_vector.cu

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,9 @@ GrowOnlyVirtualMemVec::GrowOnlyVirtualMemVec(CUmemLocationType type)
9999

100100
#if defined(XGBOOST_USE_RMM)
101101
LoggingResource *GlobalLoggingResource() {
102-
static auto mr{std::make_unique<LoggingResource>()};
102+
static std::unique_ptr<LoggingResource> mr;
103+
static std::once_flag flag;
104+
std::call_once(flag, [&] { mr = std::make_unique<LoggingResource>(); });
103105
return mr.get();
104106
}
105107
#endif // defined(XGBOOST_USE_RMM)

src/common/device_vector.cuh

Lines changed: 63 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,12 @@
11
/**
2-
* Copyright 2017-2024, XGBoost Contributors
2+
* Copyright 2017-2025, XGBoost Contributors
33
*/
44
#pragma once
55
#include <thrust/device_malloc_allocator.h> // for device_malloc_allocator
66
#include <thrust/device_ptr.h> // for device_ptr
77
#include <thrust/device_vector.h> // for device_vector
88

99
#if defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
10-
#include <rmm/device_uvector.hpp> // for device_uvector
11-
#include <rmm/exec_policy.hpp> // for exec_policy_nosync
1210
#include <rmm/mr/device/device_memory_resource.hpp> // for device_memory_resource
1311
#include <rmm/mr/device/per_device_resource.hpp> // for get_current_device_resource
1412
#include <rmm/mr/device/thrust_allocator_adaptor.hpp> // for thrust_allocator
@@ -37,6 +35,7 @@
3735

3836
#include "common.h" // for safe_cuda, HumanMemUnit
3937
#include "cuda_dr_utils.h" // for CuDriverApi
38+
#include "cuda_stream.h" // for DefaultStream
4039
#include "xgboost/logging.h"
4140
#include "xgboost/span.h" // for Span
4241

@@ -383,9 +382,9 @@ using XGBCachingDeviceAllocator = detail::XGBCachingDeviceAllocatorImpl<T>;
383382
* OOM errors.
384383
*/
385384
template <typename T>
386-
using device_vector = thrust::device_vector<T, XGBDeviceAllocator<T>>; // NOLINT
385+
using device_vector = thrust::device_vector<T, XGBDeviceAllocator<T>>; // NOLINT
387386
template <typename T>
388-
using caching_device_vector = thrust::device_vector<T, XGBCachingDeviceAllocator<T>>; // NOLINT
387+
using caching_device_vector = thrust::device_vector<T, XGBCachingDeviceAllocator<T>>; // NOLINT
389388

390389
#if defined(XGBOOST_USE_RMM)
391390
/**
@@ -450,10 +449,15 @@ template <typename T, bool is_caching>
450449
class DeviceUVectorImpl {
451450
private:
452451
#if defined(XGBOOST_USE_RMM)
453-
rmm::device_uvector<T> data_{0, rmm::cuda_stream_per_thread, GlobalLoggingResource()};
452+
rmm::device_async_resource_ref mr_{GlobalLoggingResource()};
454453
#else
455-
std::conditional_t<is_caching, ::dh::caching_device_vector<T>, ::dh::device_vector<T>> data_;
456-
#endif // defined(XGBOOST_USE_RMM)
454+
using Alloc =
455+
std::conditional_t<is_caching, dh::XGBCachingDeviceAllocator<T>, dh::XGBDeviceAllocator<T>>;
456+
#endif
457+
458+
std::size_t size_{0};
459+
std::size_t capacity_{0};
460+
std::unique_ptr<T, std::function<void(T *)>> data_;
457461

458462
public:
459463
using value_type = T; // NOLINT
@@ -470,47 +474,76 @@ class DeviceUVectorImpl {
470474
DeviceUVectorImpl(DeviceUVectorImpl &&that) = default;
471475
DeviceUVectorImpl &operator=(DeviceUVectorImpl &&that) = default;
472476

477+
[[nodiscard]] std::size_t Capacity() const { return this->capacity_; }
478+
479+
// Resize without init.
473480
void resize(std::size_t n) { // NOLINT
481+
using ::xgboost::common::SizeBytes;
482+
483+
if (n <= this->Capacity()) {
484+
this->size_ = n;
485+
// early exit as no allocation is needed.
486+
return;
487+
}
488+
CHECK_LE(this->size(), this->Capacity());
489+
auto s = ::xgboost::curt::DefaultStream();
490+
491+
decltype(data_) new_ptr{[n, this, s]() {
474492
#if defined(XGBOOST_USE_RMM)
475-
data_.resize(n, rmm::cuda_stream_per_thread);
493+
auto n_bytes = SizeBytes<T>(n);
494+
auto p = this->mr_.allocate_async(n_bytes, rmm::cuda_stream_view{s});
495+
return static_cast<T *>(p);
476496
#else
477-
data_.resize(n);
497+
auto p = Alloc{}.allocate(n);
498+
return thrust::raw_pointer_cast(p);
478499
#endif
479-
}
480-
void resize(std::size_t n, T const &v) { // NOLINT
500+
}(),
501+
[n, this, s](T *ptr) {
502+
if (ptr) {
481503
#if defined(XGBOOST_USE_RMM)
504+
auto n_bytes = SizeBytes<T>(n);
505+
this->mr_.deallocate_async(ptr, n_bytes, rmm::cuda_stream_view{s});
506+
#else
507+
Alloc{}.deallocate(thrust::device_pointer_cast(ptr), n);
508+
#endif
509+
}
510+
}};
511+
CHECK(new_ptr.get());
512+
safe_cuda(cudaMemcpyAsync(new_ptr.get(), this->data(), SizeBytes<T>(this->size()),
513+
cudaMemcpyDefault, s));
514+
this->size_ = n;
515+
this->capacity_ = n;
516+
517+
std::swap(this->data_, new_ptr);
518+
}
519+
// Resize with init
520+
void resize(std::size_t n, T const &v) { // NOLINT
482521
auto orig = this->size();
483-
data_.resize(n, rmm::cuda_stream_per_thread);
522+
this->resize(n);
484523
if (orig < n) {
485-
thrust::fill(rmm::exec_policy_nosync{}, this->begin() + orig, this->end(), v);
524+
auto exec = thrust::cuda::par_nosync.on(::xgboost::curt::DefaultStream());
525+
thrust::fill(exec, this->begin() + orig, this->end(), v);
486526
}
487-
#else
488-
data_.resize(n, v);
489-
#endif
490527
}
491528

492529
void clear() { // NOLINT
493-
#if defined(XGBOOST_USE_RMM)
494-
this->data_.resize(0, rmm::cuda_stream_per_thread);
495-
#else
496-
this->data_.clear();
497-
#endif // defined(XGBOOST_USE_RMM)
530+
this->resize(0);
498531
}
499532

500-
[[nodiscard]] std::size_t size() const { return data_.size(); } // NOLINT
501-
[[nodiscard]] bool empty() const { return this->size() == 0; } // NOLINT
533+
[[nodiscard]] std::size_t size() const { return this->size_; } // NOLINT
534+
[[nodiscard]] bool empty() const { return this->size() == 0; } // NOLINT
502535

503-
[[nodiscard]] auto begin() { return data_.begin(); } // NOLINT
504-
[[nodiscard]] auto end() { return data_.end(); } // NOLINT
536+
[[nodiscard]] auto begin() { return this->data(); } // NOLINT
537+
[[nodiscard]] auto end() { return this->data() + this->size(); } // NOLINT
505538

506539
[[nodiscard]] auto begin() const { return this->cbegin(); } // NOLINT
507540
[[nodiscard]] auto end() const { return this->cend(); } // NOLINT
508541

509-
[[nodiscard]] auto cbegin() const { return data_.cbegin(); } // NOLINT
510-
[[nodiscard]] auto cend() const { return data_.cend(); } // NOLINT
542+
[[nodiscard]] auto cbegin() const { return this->data(); } // NOLINT
543+
[[nodiscard]] auto cend() const { return this->data() + this->size(); } // NOLINT
511544

512-
[[nodiscard]] auto data() { return thrust::raw_pointer_cast(data_.data()); } // NOLINT
513-
[[nodiscard]] auto data() const { return thrust::raw_pointer_cast(data_.data()); } // NOLINT
545+
[[nodiscard]] auto data() { return this->data_.get(); } // NOLINT
546+
[[nodiscard]] auto data() const { return this->data_.get(); } // NOLINT
514547
};
515548

516549
template <typename T>

tests/cpp/common/test_device_vector.cu

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,18 @@
11
/**
2-
* Copyright 2024, XGBoost Contributors
2+
* Copyright 2024-2025, XGBoost Contributors
33
*/
44
#include <gtest/gtest.h>
5-
#include <thread> // for thread
5+
#include <thrust/iterator/counting_iterator.h> // for make_counting_iterator
6+
#include <thrust/sequence.h> // for sequence
67

7-
#include <numeric> // for iota
8-
#include <thrust/detail/sequence.inl> // for sequence
8+
#include <numeric> // for iota
9+
#include <thread> // for thread
910

1011
#include "../../../src/common/cuda_rt_utils.h" // for DrVersion
1112
#include "../../../src/common/device_helpers.cuh" // for CachingThrustPolicy, PinnedMemory
1213
#include "../../../src/common/device_vector.cuh"
1314
#include "xgboost/global_config.h" // for GlobalConfigThreadLocalStore
14-
#include "xgboost/windefs.h" // for xgboost_IS_WIN
15+
#include "xgboost/windefs.h" // for xgboost_IS_WIN
1516

1617
namespace dh {
1718
TEST(DeviceUVector, Basic) {
@@ -24,6 +25,26 @@ TEST(DeviceUVector, Basic) {
2425
auto n_bytes = sizeof(decltype(uvec)::value_type) * uvec.size();
2526
ASSERT_EQ(peak, n_bytes);
2627
std::swap(verbosity, xgboost::GlobalConfigThreadLocalStore::Get()->verbosity);
28+
29+
DeviceUVector<double> uvec1{16};
30+
ASSERT_EQ(uvec1.size(), 16);
31+
uvec1.resize(3);
32+
ASSERT_EQ(uvec1.size(), 3);
33+
ASSERT_EQ(uvec1.Capacity(), 16);
34+
ASSERT_EQ(std::distance(uvec1.begin(), uvec1.end()), uvec1.size());
35+
auto orig = uvec1.size();
36+
37+
thrust::sequence(dh::CachingThrustPolicy(), uvec1.begin(), uvec1.end(), 0);
38+
uvec1.resize(32);
39+
ASSERT_EQ(uvec1.size(), 32);
40+
ASSERT_EQ(uvec1.Capacity(), 32);
41+
auto eq = thrust::equal(dh::CachingThrustPolicy(), uvec1.cbegin(), uvec1.cbegin() + orig,
42+
thrust::make_counting_iterator(0));
43+
ASSERT_TRUE(eq);
44+
45+
uvec1.clear();
46+
ASSERT_EQ(uvec1.size(), 0);
47+
ASSERT_EQ(uvec1.Capacity(), 32);
2748
}
2849

2950
#if defined(__linux__)

0 commit comments

Comments
 (0)