11/* *
2- * Copyright 2017-2024 , XGBoost Contributors
2+ * Copyright 2017-2025 , XGBoost Contributors
33 */
44#pragma once
55#include < thrust/device_malloc_allocator.h> // for device_malloc_allocator
66#include < thrust/device_ptr.h> // for device_ptr
77#include < thrust/device_vector.h> // for device_vector
88
99#if defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
10- #include < rmm/device_uvector.hpp> // for device_uvector
11- #include < rmm/exec_policy.hpp> // for exec_policy_nosync
1210#include < rmm/mr/device/device_memory_resource.hpp> // for device_memory_resource
1311#include < rmm/mr/device/per_device_resource.hpp> // for get_current_device_resource
1412#include < rmm/mr/device/thrust_allocator_adaptor.hpp> // for thrust_allocator
3735
3836#include " common.h" // for safe_cuda, HumanMemUnit
3937#include " cuda_dr_utils.h" // for CuDriverApi
38+ #include " cuda_stream.h" // for DefaultStream
4039#include " xgboost/logging.h"
4140#include " xgboost/span.h" // for Span
4241
@@ -383,9 +382,9 @@ using XGBCachingDeviceAllocator = detail::XGBCachingDeviceAllocatorImpl<T>;
383382 * OOM errors.
384383 */
385384template <typename T>
386- using device_vector = thrust::device_vector<T, XGBDeviceAllocator<T>>; // NOLINT
385+ using device_vector = thrust::device_vector<T, XGBDeviceAllocator<T>>; // NOLINT
387386template <typename T>
388- using caching_device_vector = thrust::device_vector<T, XGBCachingDeviceAllocator<T>>; // NOLINT
387+ using caching_device_vector = thrust::device_vector<T, XGBCachingDeviceAllocator<T>>; // NOLINT
389388
390389#if defined(XGBOOST_USE_RMM)
391390/* *
@@ -450,10 +449,15 @@ template <typename T, bool is_caching>
450449class DeviceUVectorImpl {
451450 private:
452451#if defined(XGBOOST_USE_RMM)
453- rmm::device_uvector<T> data_{ 0 , rmm::cuda_stream_per_thread, GlobalLoggingResource ()};
452+ rmm::device_async_resource_ref mr_{ GlobalLoggingResource ()};
454453#else
455- std::conditional_t <is_caching, ::dh::caching_device_vector<T>, ::dh::device_vector<T>> data_;
456- #endif // defined(XGBOOST_USE_RMM)
454+ using Alloc =
455+ std::conditional_t <is_caching, dh::XGBCachingDeviceAllocator<T>, dh::XGBDeviceAllocator<T>>;
456+ #endif
457+
458+ std::size_t size_{0 };
459+ std::size_t capacity_{0 };
460+ std::unique_ptr<T, std::function<void (T *)>> data_;
457461
458462 public:
459463 using value_type = T; // NOLINT
@@ -470,47 +474,76 @@ class DeviceUVectorImpl {
470474 DeviceUVectorImpl (DeviceUVectorImpl &&that) = default ;
471475 DeviceUVectorImpl &operator =(DeviceUVectorImpl &&that) = default ;
472476
477+ [[nodiscard]] std::size_t Capacity () const { return this ->capacity_ ; }
478+
479+ // Resize without init.
473480 void resize (std::size_t n) { // NOLINT
481+ using ::xgboost::common::SizeBytes;
482+
483+ if (n <= this ->Capacity ()) {
484+ this ->size_ = n;
485+ // early exit as no allocation is needed.
486+ return ;
487+ }
488+ CHECK_LE (this ->size (), this ->Capacity ());
489+ auto s = ::xgboost::curt::DefaultStream ();
490+
491+ decltype (data_) new_ptr{[n, this , s]() {
474492#if defined(XGBOOST_USE_RMM)
475- data_.resize (n, rmm::cuda_stream_per_thread);
493+ auto n_bytes = SizeBytes<T>(n);
494+ auto p = this ->mr_ .allocate_async (n_bytes, rmm::cuda_stream_view{s});
495+ return static_cast <T *>(p);
476496#else
477- data_.resize (n);
497+ auto p = Alloc{}.allocate (n);
498+ return thrust::raw_pointer_cast (p);
478499#endif
479- }
480- void resize (std::size_t n, T const &v) { // NOLINT
500+ }(),
501+ [n, this , s](T *ptr) {
502+ if (ptr) {
481503#if defined(XGBOOST_USE_RMM)
504+ auto n_bytes = SizeBytes<T>(n);
505+ this ->mr_ .deallocate_async (ptr, n_bytes, rmm::cuda_stream_view{s});
506+ #else
507+ Alloc{}.deallocate (thrust::device_pointer_cast (ptr), n);
508+ #endif
509+ }
510+ }};
511+ CHECK (new_ptr.get ());
512+ safe_cuda (cudaMemcpyAsync (new_ptr.get (), this ->data (), SizeBytes<T>(this ->size ()),
513+ cudaMemcpyDefault, s));
514+ this ->size_ = n;
515+ this ->capacity_ = n;
516+
517+ std::swap (this ->data_ , new_ptr);
518+ }
519+ // Resize with init
520+ void resize (std::size_t n, T const &v) { // NOLINT
482521 auto orig = this ->size ();
483- data_. resize (n, rmm::cuda_stream_per_thread );
522+ this -> resize (n);
484523 if (orig < n) {
485- thrust::fill (rmm::exec_policy_nosync{}, this ->begin () + orig, this ->end (), v);
524+ auto exec = thrust::cuda::par_nosync.on (::xgboost::curt::DefaultStream ());
525+ thrust::fill (exec, this ->begin () + orig, this ->end (), v);
486526 }
487- #else
488- data_.resize (n, v);
489- #endif
490527 }
491528
492529 void clear () { // NOLINT
493- #if defined(XGBOOST_USE_RMM)
494- this ->data_ .resize (0 , rmm::cuda_stream_per_thread);
495- #else
496- this ->data_ .clear ();
497- #endif // defined(XGBOOST_USE_RMM)
530+ this ->resize (0 );
498531 }
499532
500- [[nodiscard]] std::size_t size () const { return data_. size () ; } // NOLINT
501- [[nodiscard]] bool empty () const { return this ->size () == 0 ; } // NOLINT
533+ [[nodiscard]] std::size_t size () const { return this -> size_ ; } // NOLINT
534+ [[nodiscard]] bool empty () const { return this ->size () == 0 ; } // NOLINT
502535
503- [[nodiscard]] auto begin () { return data_. begin (); } // NOLINT
504- [[nodiscard]] auto end () { return data_. end (); } // NOLINT
536+ [[nodiscard]] auto begin () { return this -> data (); } // NOLINT
537+ [[nodiscard]] auto end () { return this -> data () + this -> size (); } // NOLINT
505538
506539 [[nodiscard]] auto begin () const { return this ->cbegin (); } // NOLINT
507540 [[nodiscard]] auto end () const { return this ->cend (); } // NOLINT
508541
509- [[nodiscard]] auto cbegin () const { return data_. cbegin (); } // NOLINT
510- [[nodiscard]] auto cend () const { return data_. cend (); } // NOLINT
542+ [[nodiscard]] auto cbegin () const { return this -> data (); } // NOLINT
543+ [[nodiscard]] auto cend () const { return this -> data () + this -> size (); } // NOLINT
511544
512- [[nodiscard]] auto data () { return thrust::raw_pointer_cast ( data_.data () ); } // NOLINT
513- [[nodiscard]] auto data () const { return thrust::raw_pointer_cast ( data_.data () ); } // NOLINT
545+ [[nodiscard]] auto data () { return this -> data_ .get ( ); } // NOLINT
546+ [[nodiscard]] auto data () const { return this -> data_ .get ( ); } // NOLINT
514547};
515548
516549template <typename T>
0 commit comments