44#ifndef XGBOOST_COMMON_ALGORITHM_CUH_
55#define XGBOOST_COMMON_ALGORITHM_CUH_
66
7- #include < thrust/copy.h> // for copy
8- #include < thrust/iterator/counting_iterator.h> // for make_counting_iterator
9- #include < thrust/sort.h> // for stable_sort_by_key
10- #include < thrust/tuple.h> // for tuple, get
7+ #include < thrust/copy.h> // for copy
8+ #include < thrust/iterator/counting_iterator.h> // for make_counting_iterator
9+ #include < thrust/sort.h> // for stable_sort_by_key
10+ #include < thrust/tuple.h> // for tuple, get
1111
1212#include < cstddef> // size_t
1313#include < cstdint> // int32_t
1818
1919#include " common.h" // safe_cuda
2020#include " cuda_context.cuh" // CUDAContext
21+ #include " cuda_stream.h" // for StreamRef
2122#include " device_helpers.cuh" // TemporaryArray,SegmentId,LaunchN,Iota
2223#include " device_vector.cuh" // for device_vector
2324#include " xgboost/base.h" // XGBOOST_DEVICE
2425#include " xgboost/context.h" // Context
25- #include " xgboost/linalg.h" // for VectorView
2626#include " xgboost/logging.h" // CHECK
2727#include " xgboost/span.h" // Span,byte
2828
2929namespace xgboost ::common {
3030namespace detail {
3131
// Sort-direction constants, selected on CUB_VERSION. When CUB_VERSION >= 300000
// the direction is a cub::SortOrder enumerator; otherwise it is a plain bool,
// where false means ascending and true means descending. Elsewhere in this file
// the constants are passed as the first template argument of
// cub::DispatchRadixSort.
3232#if CUB_VERSION >= 300000
33- constexpr auto kCubSortOrderAscending = cub::SortOrder::Ascending;
34- constexpr auto kCubSortOrderDescending = cub::SortOrder::Descending;
33+ constexpr auto kCubSortOrderAscending = cub::SortOrder::Ascending;
34+ constexpr auto kCubSortOrderDescending = cub::SortOrder::Descending;
3535#else
36- constexpr bool kCubSortOrderAscending = false ;
37- constexpr bool kCubSortOrderDescending = true ;
36+ constexpr bool kCubSortOrderAscending = false ;
37+ constexpr bool kCubSortOrderDescending = true ;
3838#endif
3939
4040// Wrapper around cub sort to define is_descending
@@ -70,7 +70,7 @@ void DeviceSegmentedRadixSortPair(void *d_temp_storage,
7070 const ValueT *d_values_in, ValueT *d_values_out,
7171 std::size_t num_items, std::size_t num_segments,
7272 BeginOffsetIteratorT d_begin_offsets,
73- EndOffsetIteratorT d_end_offsets, dh::CUDAStreamView stream,
73+ EndOffsetIteratorT d_end_offsets, curt::StreamRef stream,
7474 int begin_bit = 0 , int end_bit = sizeof (KeyT) * 8) {
7575 cub::DoubleBuffer<KeyT> d_keys (const_cast <KeyT *>(d_keys_in), d_keys_out);
7676 cub::DoubleBuffer<ValueT> d_values (const_cast <ValueT *>(d_values_in), d_values_out);
@@ -198,7 +198,7 @@ void SegmentedArgMergeSort(Context const *ctx, SegIt seg_begin, SegIt seg_end, V
198198 if (thrust::get<0 >(l) != thrust::get<0 >(r)) {
199199 return thrust::get<0 >(l) < thrust::get<0 >(r); // segment index
200200 }
201- return thrust::get<1 >(l) < thrust::get<1 >(r); // residue
201+ return thrust::get<1 >(l) < thrust::get<1 >(r); // residue
202202 });
203203}
204204
@@ -224,46 +224,54 @@ void ArgSort(Context const *ctx, Span<U> keys, Span<IdxT> sorted_idx) {
224224 if (accending) {
225225 void *d_temp_storage = nullptr ;
226226#if THRUST_MAJOR_VERSION >= 2
227- dh::safe_cuda ((cub::DispatchRadixSort<detail::kCubSortOrderAscending , KeyT, ValueT, OffsetT>::Dispatch (
228- d_temp_storage, bytes, d_keys, d_values, sorted_idx.size (), 0 , sizeof (KeyT) * 8 , false ,
229- cuctx->Stream ())));
227+ dh::safe_cuda (
228+ (cub::DispatchRadixSort<detail::kCubSortOrderAscending , KeyT, ValueT, OffsetT>::Dispatch (
229+ d_temp_storage, bytes, d_keys, d_values, sorted_idx.size (), 0 , sizeof (KeyT) * 8 , false ,
230+ cuctx->Stream ())));
230231#else
231- dh::safe_cuda ((cub::DispatchRadixSort<detail::kCubSortOrderAscending , KeyT, ValueT, OffsetT>::Dispatch (
232- d_temp_storage, bytes, d_keys, d_values, sorted_idx.size (), 0 , sizeof (KeyT) * 8 , false ,
233- nullptr , false )));
232+ dh::safe_cuda (
233+ (cub::DispatchRadixSort<detail::kCubSortOrderAscending , KeyT, ValueT, OffsetT>::Dispatch (
234+ d_temp_storage, bytes, d_keys, d_values, sorted_idx.size (), 0 , sizeof (KeyT) * 8 , false ,
235+ nullptr , false )));
234236#endif
235237 dh::TemporaryArray<char > storage (bytes);
236238 d_temp_storage = storage.data ().get ();
237239#if THRUST_MAJOR_VERSION >= 2
238- dh::safe_cuda ((cub::DispatchRadixSort<detail::kCubSortOrderAscending , KeyT, ValueT, OffsetT>::Dispatch (
239- d_temp_storage, bytes, d_keys, d_values, sorted_idx.size (), 0 , sizeof (KeyT) * 8 , false ,
240- cuctx->Stream ())));
240+ dh::safe_cuda (
241+ (cub::DispatchRadixSort<detail::kCubSortOrderAscending , KeyT, ValueT, OffsetT>::Dispatch (
242+ d_temp_storage, bytes, d_keys, d_values, sorted_idx.size (), 0 , sizeof (KeyT) * 8 , false ,
243+ cuctx->Stream ())));
241244#else
242- dh::safe_cuda ((cub::DispatchRadixSort<detail::kCubSortOrderAscending , KeyT, ValueT, OffsetT>::Dispatch (
243- d_temp_storage, bytes, d_keys, d_values, sorted_idx.size (), 0 , sizeof (KeyT) * 8 , false ,
244- nullptr , false )));
245+ dh::safe_cuda (
246+ (cub::DispatchRadixSort<detail::kCubSortOrderAscending , KeyT, ValueT, OffsetT>::Dispatch (
247+ d_temp_storage, bytes, d_keys, d_values, sorted_idx.size (), 0 , sizeof (KeyT) * 8 , false ,
248+ nullptr , false )));
245249#endif
246250 } else {
247251 void *d_temp_storage = nullptr ;
248252#if THRUST_MAJOR_VERSION >= 2
249- dh::safe_cuda ((cub::DispatchRadixSort<detail::kCubSortOrderDescending , KeyT, ValueT, OffsetT>::Dispatch (
250- d_temp_storage, bytes, d_keys, d_values, sorted_idx.size (), 0 , sizeof (KeyT) * 8 , false ,
251- cuctx->Stream ())));
253+ dh::safe_cuda (
254+ (cub::DispatchRadixSort<detail::kCubSortOrderDescending , KeyT, ValueT, OffsetT>::Dispatch (
255+ d_temp_storage, bytes, d_keys, d_values, sorted_idx.size (), 0 , sizeof (KeyT) * 8 , false ,
256+ cuctx->Stream ())));
252257#else
253- dh::safe_cuda ((cub::DispatchRadixSort<detail::kCubSortOrderDescending , KeyT, ValueT, OffsetT>::Dispatch (
254- d_temp_storage, bytes, d_keys, d_values, sorted_idx.size (), 0 , sizeof (KeyT) * 8 , false ,
255- nullptr , false )));
258+ dh::safe_cuda (
259+ (cub::DispatchRadixSort<detail::kCubSortOrderDescending , KeyT, ValueT, OffsetT>::Dispatch (
260+ d_temp_storage, bytes, d_keys, d_values, sorted_idx.size (), 0 , sizeof (KeyT) * 8 , false ,
261+ nullptr , false )));
256262#endif
257263 dh::TemporaryArray<char > storage (bytes);
258264 d_temp_storage = storage.data ().get ();
259265#if THRUST_MAJOR_VERSION >= 2
260- dh::safe_cuda ((cub::DispatchRadixSort<detail::kCubSortOrderDescending , KeyT, ValueT, OffsetT>::Dispatch (
261- d_temp_storage, bytes, d_keys, d_values, sorted_idx.size (), 0 , sizeof (KeyT) * 8 , false ,
262- cuctx->Stream ())));
266+ dh::safe_cuda (
267+ (cub::DispatchRadixSort<detail::kCubSortOrderDescending , KeyT, ValueT, OffsetT>::Dispatch (
268+ d_temp_storage, bytes, d_keys, d_values, sorted_idx.size (), 0 , sizeof (KeyT) * 8 , false ,
269+ cuctx->Stream ())));
263270#else
264- dh::safe_cuda ((cub::DispatchRadixSort<detail::kCubSortOrderDescending , KeyT, ValueT, OffsetT>::Dispatch (
265- d_temp_storage, bytes, d_keys, d_values, sorted_idx.size (), 0 , sizeof (KeyT) * 8 , false ,
266- nullptr , false )));
271+ dh::safe_cuda (
272+ (cub::DispatchRadixSort<detail::kCubSortOrderDescending , KeyT, ValueT, OffsetT>::Dispatch (
273+ d_temp_storage, bytes, d_keys, d_values, sorted_idx.size (), 0 , sizeof (KeyT) * 8 , false ,
274+ nullptr , false )));
267275#endif
268276 }
269277
@@ -330,15 +338,15 @@ void InclusiveSum(Context const *ctx, InputIteratorT d_in, OutputIteratorT d_out
330338}
331339
// Convenience wrapper around cub::DeviceRunLengthEncode::Encode that manages
// the temporary storage for the caller.
//
// stream: CUDA stream handle, passed as the last argument of both Encode calls.
// args:   passed through unchanged to both Encode calls, placed between the
//         temporary-storage arguments and the stream.
//
// Both calls go through dh::safe_cuda so the status returned by CUB is checked.
332340template <typename ... Args>
333- void RunLengthEncode (dh::CUDAStreamView stream, Args &&...args) {
341+ void RunLengthEncode (curt::StreamRef stream, Args &&...args) {
334342 std::size_t n_bytes = 0 ;
// First call: null temporary storage. Per the CUB two-phase convention this
// only reports the required size through n_bytes.
335343 dh::safe_cuda (cub::DeviceRunLengthEncode::Encode (nullptr , n_bytes, args..., stream));
// Scratch buffer of the reported size; it is a local and goes out of scope at
// the end of this function.
336344 dh::CachingDeviceUVector<char > tmp (n_bytes);
// Second call: same arguments, now with the scratch buffer supplied.
337345 dh::safe_cuda (cub::DeviceRunLengthEncode::Encode (tmp.data (), n_bytes, args..., stream));
338346}
339347
340348template <typename ... Args>
341- void SegmentedSum (dh::CUDAStreamView stream, Args &&...args) {
349+ void SegmentedSum (curt::StreamRef stream, Args &&...args) {
342350 std::size_t n_bytes = 0 ;
343351 dh::safe_cuda (cub::DeviceSegmentedReduce::Sum (nullptr , n_bytes, args..., stream));
344352 dh::CachingDeviceUVector<char > tmp (n_bytes);
0 commit comments