Commit 470573a

Extract the stream, event, and stream view. (#11706)
1 parent a3a5786 commit 470573a

31 files changed with 276 additions and 240 deletions
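
Across the files below, this commit moves the CUDA stream and event wrappers out of device_helpers.cuh into a dedicated header, src/common/cuda_stream.h, and renames them: dh::CUDAStream becomes curt::Stream, dh::CUDAStreamView becomes curt::StreamRef, dh::CUDAEvent becomes curt::Event, and dh::DefaultStream() becomes curt::DefaultStream(). The sketch below is an assumed, minimal shape for that header, inferred only from the call sites visible in the hunks (Record, Wait, Sync, View, DefaultStream); the actual XGBoost implementation differs in details such as error checking and extra utilities.

// Sketch only: an assumed shape for the extracted cuda_stream.h, reconstructed from
// usage in this commit; not the actual XGBoost implementation.
#pragma once
#include <cuda_runtime_api.h>

namespace curt {
// Non-owning stream handle (replaces dh::CUDAStreamView).
class StreamRef {
  cudaStream_t stream_{nullptr};

 public:
  StreamRef() = default;
  explicit StreamRef(cudaStream_t s) : stream_{s} {}
  // Queue a wait so that later work on this stream starts only after `e` has fired.
  void Wait(cudaEvent_t e) const { cudaStreamWaitEvent(stream_, e, 0); }
  // Sync(false) returns the status instead of aborting, as used by NCCLComm::Block().
  cudaError_t Sync(bool abort_on_error = true) const {
    (void)abort_on_error;  // the real implementation asserts when this is true
    return cudaStreamSynchronize(stream_);
  }
  operator cudaStream_t() const { return stream_; }  // pass directly to CUB/NCCL calls
};

// Owning RAII event (replaces dh::CUDAEvent).
class Event {
  cudaEvent_t event_{nullptr};

 public:
  Event() { cudaEventCreateWithFlags(&event_, cudaEventDisableTiming); }
  ~Event() { if (event_) { cudaEventDestroy(event_); } }
  Event(Event const&) = delete;
  Event& operator=(Event const&) = delete;
  void Record(StreamRef stream) { cudaEventRecord(event_, stream); }
  operator cudaEvent_t() const { return event_; }
};

// Owning RAII stream (replaces dh::CUDAStream); hands out non-owning views.
class Stream {
  cudaStream_t stream_{nullptr};

 public:
  Stream() { cudaStreamCreateWithFlags(&stream_, cudaStreamNonBlocking); }
  ~Stream() { if (stream_) { cudaStreamDestroy(stream_); } }
  Stream(Stream const&) = delete;
  Stream& operator=(Stream const&) = delete;
  StreamRef View() const { return StreamRef{stream_}; }
};

// View of the default stream (legacy or per-thread, depending on build flags).
inline StreamRef DefaultStream() { return StreamRef{cudaStreamLegacy}; }
}  // namespace curt

In the diffs that follow, NCCLColl owns a curt::Stream and hands stream_.View() to helpers such as AsyncLaunch, which only need the non-owning curt::StreamRef.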

plugin/federated/federated_comm.cuh

Lines changed: 3 additions & 3 deletions
@@ -1,18 +1,18 @@
 /**
- * Copyright 2023-2024, XGBoost Contributors
+ * Copyright 2023-2025, XGBoost Contributors
  */
 #pragma once

 #include <memory>  // for shared_ptr

 #include "../../src/collective/coll.h"           // for Coll
-#include "../../src/common/device_helpers.cuh"   // for CUDAStreamView
+#include "../../src/common/cuda_stream.h"        // for StreamRef
 #include "federated_comm.h"                      // for FederatedComm
 #include "xgboost/context.h"                     // for Context

 namespace xgboost::collective {
 class CUDAFederatedComm : public FederatedComm {
-  dh::CUDAStreamView stream_;
+  curt::StreamRef stream_;

  public:
   explicit CUDAFederatedComm(Context const* ctx, std::shared_ptr<FederatedComm const> impl);

src/collective/coll.cu

Lines changed: 14 additions & 13 deletions
@@ -13,7 +13,8 @@
 #include <type_traits>  // for invoke_result_t, is_same_v, enable_if_t
 #include <utility>      // for move

-#include "../common/device_helpers.cuh"  // for CUDAStreamView, CUDAEvent, device_vector
+#include "../common/cuda_stream.h"       // for StreamRef, Event
+#include "../common/device_helpers.cuh"  // for device_vector
 #include "../common/threadpool.h"        // for ThreadPool
 #include "../common/utils.h"             // for MakeCleanup
 #include "../data/array_interface.h"     // for ArrayInterfaceHandler
@@ -87,16 +88,16 @@ struct Chan {
 };
 }  // namespace

-template <typename Fn, typename R = std::invoke_result_t<Fn, dh::CUDAStreamView>>
+template <typename Fn, typename R = std::invoke_result_t<Fn, curt::StreamRef>>
 [[nodiscard]] std::enable_if_t<std::is_same_v<R, Result>, Result> AsyncLaunch(
     common::ThreadPool* pool, NCCLComm const* nccl, std::shared_ptr<NcclStub> stub,
-    dh::CUDAStreamView stream, Fn&& fn) {
-  dh::CUDAEvent e0;
+    curt::StreamRef stream, Fn&& fn) {
+  curt::Event e0;
   e0.Record(nccl->Stream());
   stream.Wait(e0);

   auto cleanup = common::MakeCleanup([&] {
-    dh::CUDAEvent e1;
+    curt::Event e1;
     e1.Record(stream);
     nccl->Stream().Wait(e1);
   });
@@ -180,7 +181,7 @@ bool IsBitwiseOp(Op const& op) {
 }

 template <typename Func>
-void RunBitwiseAllreduce(dh::CUDAStreamView stream, common::Span<std::int8_t> out_buffer,
+void RunBitwiseAllreduce(curt::StreamRef stream, common::Span<std::int8_t> out_buffer,
                          std::int8_t const* device_buffer, Func func, std::int32_t world_size,
                          std::size_t size) {
   dh::LaunchN(size, stream, [=] __device__(std::size_t idx) {
@@ -194,13 +195,13 @@ void RunBitwiseAllreduce(dh::CUDAStreamView stream, common::Span<std::int8_t> ou

 [[nodiscard]] Result BitwiseAllReduce(common::ThreadPool* pool, NCCLComm const* pcomm,
                                       common::Span<std::int8_t> data, Op op,
-                                      dh::CUDAStreamView stream) {
+                                      curt::StreamRef stream) {
   dh::device_vector<std::int8_t> buffer(data.size() * pcomm->World());
   auto* device_buffer = buffer.data().get();
   auto stub = pcomm->Stub();

   // First gather data from all the workers.
-  auto rc = AsyncLaunch(pool, pcomm, stub, stream, [&](dh::CUDAStreamView s) {
+  auto rc = AsyncLaunch(pool, pcomm, stub, stream, [&](curt::StreamRef s) {
     return stub->Allgather(data.data(), device_buffer, data.size(), ncclInt8, pcomm->Handle(), s);
   });
   if (!rc.OK()) {
@@ -263,7 +264,7 @@ ncclRedOp_t GetNCCLRedOp(Op const& op) {
         using T = decltype(t);
         auto rdata = common::RestoreType<T>(data);
         return AsyncLaunch(
-            &this->pool_, nccl, stub, this->stream_.View(), [&](dh::CUDAStreamView s) {
+            &this->pool_, nccl, stub, this->stream_.View(), [&](curt::StreamRef s) {
              return stub->Allreduce(data.data(), data.data(), rdata.size(), GetNCCLType(type),
                                     GetNCCLRedOp(op), nccl->Handle(), s);
            });
@@ -285,7 +286,7 @@ ncclRedOp_t GetNCCLRedOp(Op const& op) {

   return Success() << [&] {
     return AsyncLaunch(&this->pool_, nccl, stub, this->stream_.View(),
-                       [data, nccl, root, stub](dh::CUDAStreamView s) {
+                       [data, nccl, root, stub](curt::StreamRef s) {
                          return stub->Broadcast(data.data(), data.data(), data.size_bytes(),
                                                 ncclInt8, root, nccl->Handle(), s);
                        });
@@ -306,7 +307,7 @@ ncclRedOp_t GetNCCLRedOp(Op const& op) {
   auto send = data.subspan(comm.Rank() * size, size);
   return Success() << [&] {
     return AsyncLaunch(&this->pool_, nccl, stub, this->stream_.View(),
-                       [send, data, size, nccl, stub](dh::CUDAStreamView s) {
+                       [send, data, size, nccl, stub](curt::StreamRef s) {
                          return stub->Allgather(send.data(), data.data(), size, ncclInt8,
                                                 nccl->Handle(), s);
                        });
@@ -321,7 +322,7 @@ namespace cuda_impl {
  *
  * https://arxiv.org/abs/1812.05964
  */
-Result BroadcastAllgatherV(NCCLComm const* comm, dh::CUDAStreamView s,
+Result BroadcastAllgatherV(NCCLComm const* comm, curt::StreamRef s,
                            common::Span<std::int8_t const> data,
                            common::Span<std::int64_t const> sizes, common::Span<std::int8_t> recv) {
   auto stub = comm->Stub();
@@ -379,7 +380,7 @@ Result BroadcastAllgatherV(NCCLComm const* comm, dh::CUDAStreamView s,
       };
     }
     case AllgatherVAlgo::kBcast: {
-      return AsyncLaunch(&this->pool_, nccl, stub, this->stream_.View(), [&](dh::CUDAStreamView s) {
+      return AsyncLaunch(&this->pool_, nccl, stub, this->stream_.View(), [&](curt::StreamRef s) {
         return cuda_impl::BroadcastAllgatherV(nccl, s, data, sizes, recv);
       });
     }
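
The AsyncLaunch change above keeps the same cross-stream ordering pattern with the renamed types: record an event on the communicator's stream, make the worker stream wait on it before the NCCL call, then mirror the handshake during cleanup so the communicator's stream waits for the worker stream. A self-contained CUDA sketch of that handshake (illustrative only, not XGBoost code; the stream and event names are made up):

#include <cuda_runtime_api.h>

int main() {
  cudaStream_t comm_stream, work_stream;
  cudaStreamCreate(&comm_stream);
  cudaStreamCreate(&work_stream);

  cudaEvent_t e0, e1;
  cudaEventCreateWithFlags(&e0, cudaEventDisableTiming);
  cudaEventCreateWithFlags(&e1, cudaEventDisableTiming);

  // e0: work_stream must not start until everything already queued on comm_stream is done.
  cudaEventRecord(e0, comm_stream);
  cudaStreamWaitEvent(work_stream, e0, 0);

  // ... enqueue collective work on work_stream here ...

  // e1: comm_stream must not continue until the work on work_stream is done.
  cudaEventRecord(e1, work_stream);
  cudaStreamWaitEvent(comm_stream, e1, 0);

  cudaStreamSynchronize(comm_stream);
  cudaEventDestroy(e0);
  cudaEventDestroy(e1);
  cudaStreamDestroy(comm_stream);
  cudaStreamDestroy(work_stream);
  return 0;
}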

src/collective/coll.cuh

Lines changed: 8 additions & 8 deletions
@@ -1,21 +1,21 @@
 /**
- * Copyright 2023-2024, XGBoost Contributors
+ * Copyright 2023-2025, XGBoost Contributors
  */
 #pragma once

 #include <cstdint>  // for int8_t, int64_t

-#include "../common/device_helpers.cuh"  // for CUDAStream
-#include "../common/threadpool.h"        // for ThreadPool
-#include "../data/array_interface.h"     // for ArrayInterfaceHandler
-#include "coll.h"                        // for Coll
-#include "comm.h"                        // for Comm
-#include "xgboost/span.h"                // for Span
+#include "../common/cuda_stream.h"    // for Stream
+#include "../common/threadpool.h"     // for ThreadPool
+#include "../data/array_interface.h"  // for ArrayInterfaceHandler
+#include "coll.h"                     // for Coll
+#include "comm.h"                     // for Comm
+#include "xgboost/span.h"             // for Span

 namespace xgboost::collective {
 class NCCLColl : public Coll {
   common::ThreadPool pool_;
-  dh::CUDAStream stream_;
+  curt::Stream stream_;

  public:
   NCCLColl();

src/collective/comm.cu

Lines changed: 2 additions & 2 deletions
@@ -1,5 +1,5 @@
 /**
- * Copyright 2023-2024, XGBoost Contributors
+ * Copyright 2023-2025, XGBoost Contributors
  */
 #if defined(XGBOOST_USE_NCCL)
 #include <algorithm>  // for sort
@@ -113,7 +113,7 @@ NCCLComm::NCCLComm(Context const* ctx, Comm const& root, std::shared_ptr<Coll> p

   for (std::int32_t r = 0; r < root.World(); ++r) {
     this->channels_.emplace_back(
-        std::make_shared<NCCLChannel>(root, r, nccl_comm_, stub_, dh::DefaultStream()));
+        std::make_shared<NCCLChannel>(root, r, nccl_comm_, stub_, curt::DefaultStream()));
   }
 }

src/collective/comm.cuh

Lines changed: 7 additions & 7 deletions
@@ -1,5 +1,5 @@
 /**
- * Copyright 2023, XGBoost Contributors
+ * Copyright 2023-2025, XGBoost Contributors
  */
 #pragma once

@@ -9,7 +9,7 @@

 #include <utility>  // for move

-#include "../common/device_helpers.cuh"
+#include "../common/cuda_stream.h"  // for StreamRef
 #include "coll.h"
 #include "comm.h"
 #include "nccl_stub.h"  // for NcclStub
@@ -30,7 +30,7 @@ class NCCLComm : public Comm {
   ncclComm_t nccl_comm_{nullptr};
   std::shared_ptr<NcclStub> stub_;
   ncclUniqueId nccl_unique_id_{};
-  dh::CUDAStreamView stream_;
+  curt::StreamRef stream_;
   std::string nccl_path_;

  public:
@@ -45,7 +45,7 @@ class NCCLComm : public Comm {
   }
   ~NCCLComm() override;
   [[nodiscard]] bool IsFederated() const override { return false; }
-  [[nodiscard]] dh::CUDAStreamView Stream() const { return stream_; }
+  [[nodiscard]] curt::StreamRef Stream() const { return stream_; }
   [[nodiscard]] Result Block() const override {
     auto rc = this->Stream().Sync(false);
     return GetCUDAResult(rc);
@@ -60,16 +60,16 @@ class NCCLChannel : public Channel {
   std::int32_t rank_{-1};
   ncclComm_t nccl_comm_{};
   std::shared_ptr<NcclStub> stub_;
-  dh::CUDAStreamView stream_;
+  curt::StreamRef stream_;

  public:
   explicit NCCLChannel(Comm const& comm, std::int32_t rank, ncclComm_t nccl_comm,
-                       std::shared_ptr<NcclStub> stub, dh::CUDAStreamView stream)
+                       std::shared_ptr<NcclStub> stub, curt::StreamRef stream)
       : rank_{rank},
         nccl_comm_{nccl_comm},
         stub_{std::move(stub)},
         Channel{comm, nullptr},
-        stream_{stream} {}
+        stream_{std::move(stream)} {}

   [[nodiscard]] Result SendAll(std::int8_t const* ptr, std::size_t n) override {
     return stub_->Send(ptr, n, ncclInt8, rank_, nccl_comm_, stream_);

src/common/algorithm.cuh

Lines changed: 45 additions & 37 deletions
@@ -4,10 +4,10 @@
 #ifndef XGBOOST_COMMON_ALGORITHM_CUH_
 #define XGBOOST_COMMON_ALGORITHM_CUH_

-#include <thrust/copy.h>                        // for copy
-#include <thrust/iterator/counting_iterator.h>  // for make_counting_iterator
-#include <thrust/sort.h>                        // for stable_sort_by_key
-#include <thrust/tuple.h>                       // for tuple, get
+#include <thrust/copy.h>                         // for copy
+#include <thrust/iterator/counting_iterator.h>   // for make_counting_iterator
+#include <thrust/sort.h>                         // for stable_sort_by_key
+#include <thrust/tuple.h>                        // for tuple, get

 #include <cstddef>  // size_t
 #include <cstdint>  // int32_t
@@ -18,23 +18,23 @@

 #include "common.h"            // safe_cuda
 #include "cuda_context.cuh"    // CUDAContext
+#include "cuda_stream.h"       // for StreamRef
 #include "device_helpers.cuh"  // TemporaryArray,SegmentId,LaunchN,Iota
 #include "device_vector.cuh"   // for device_vector
 #include "xgboost/base.h"      // XGBOOST_DEVICE
 #include "xgboost/context.h"   // Context
-#include "xgboost/linalg.h"    // for VectorView
 #include "xgboost/logging.h"   // CHECK
 #include "xgboost/span.h"      // Span,byte

 namespace xgboost::common {
 namespace detail {

 #if CUB_VERSION >= 300000
-constexpr auto kCubSortOrderAscending = cub::SortOrder::Ascending;
-constexpr auto kCubSortOrderDescending = cub::SortOrder::Descending;
+constexpr auto kCubSortOrderAscending = cub::SortOrder::Ascending;
+constexpr auto kCubSortOrderDescending = cub::SortOrder::Descending;
 #else
-constexpr bool kCubSortOrderAscending = false;
-constexpr bool kCubSortOrderDescending = true;
+constexpr bool kCubSortOrderAscending = false;
+constexpr bool kCubSortOrderDescending = true;
 #endif

 // Wrapper around cub sort to define is_decending
@@ -70,7 +70,7 @@ void DeviceSegmentedRadixSortPair(void *d_temp_storage,
                                   const ValueT *d_values_in, ValueT *d_values_out,
                                   std::size_t num_items, std::size_t num_segments,
                                   BeginOffsetIteratorT d_begin_offsets,
-                                  EndOffsetIteratorT d_end_offsets, dh::CUDAStreamView stream,
+                                  EndOffsetIteratorT d_end_offsets, curt::StreamRef stream,
                                   int begin_bit = 0, int end_bit = sizeof(KeyT) * 8) {
   cub::DoubleBuffer<KeyT> d_keys(const_cast<KeyT *>(d_keys_in), d_keys_out);
   cub::DoubleBuffer<ValueT> d_values(const_cast<ValueT *>(d_values_in), d_values_out);
@@ -198,7 +198,7 @@ void SegmentedArgMergeSort(Context const *ctx, SegIt seg_begin, SegIt seg_end, V
     if (thrust::get<0>(l) != thrust::get<0>(r)) {
       return thrust::get<0>(l) < thrust::get<0>(r);  // segment index
     }
-    return thrust::get<1>(l) < thrust::get<1>(r); // residue
+    return thrust::get<1>(l) < thrust::get<1>(r);  // residue
   });
 }

@@ -224,46 +224,54 @@ void ArgSort(Context const *ctx, Span<U> keys, Span<IdxT> sorted_idx) {
   if (accending) {
     void *d_temp_storage = nullptr;
 #if THRUST_MAJOR_VERSION >= 2
-    dh::safe_cuda((cub::DispatchRadixSort<detail::kCubSortOrderAscending, KeyT, ValueT, OffsetT>::Dispatch(
-        d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0, sizeof(KeyT) * 8, false,
-        cuctx->Stream())));
+    dh::safe_cuda(
+        (cub::DispatchRadixSort<detail::kCubSortOrderAscending, KeyT, ValueT, OffsetT>::Dispatch(
+            d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0, sizeof(KeyT) * 8, false,
+            cuctx->Stream())));
 #else
-    dh::safe_cuda((cub::DispatchRadixSort<detail::kCubSortOrderAscending, KeyT, ValueT, OffsetT>::Dispatch(
-        d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0, sizeof(KeyT) * 8, false,
-        nullptr, false)));
+    dh::safe_cuda(
+        (cub::DispatchRadixSort<detail::kCubSortOrderAscending, KeyT, ValueT, OffsetT>::Dispatch(
+            d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0, sizeof(KeyT) * 8, false,
+            nullptr, false)));
 #endif
     dh::TemporaryArray<char> storage(bytes);
     d_temp_storage = storage.data().get();
 #if THRUST_MAJOR_VERSION >= 2
-    dh::safe_cuda((cub::DispatchRadixSort<detail::kCubSortOrderAscending, KeyT, ValueT, OffsetT>::Dispatch(
-        d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0, sizeof(KeyT) * 8, false,
-        cuctx->Stream())));
+    dh::safe_cuda(
+        (cub::DispatchRadixSort<detail::kCubSortOrderAscending, KeyT, ValueT, OffsetT>::Dispatch(
+            d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0, sizeof(KeyT) * 8, false,
+            cuctx->Stream())));
 #else
-    dh::safe_cuda((cub::DispatchRadixSort<detail::kCubSortOrderAscending, KeyT, ValueT, OffsetT>::Dispatch(
-        d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0, sizeof(KeyT) * 8, false,
-        nullptr, false)));
+    dh::safe_cuda(
+        (cub::DispatchRadixSort<detail::kCubSortOrderAscending, KeyT, ValueT, OffsetT>::Dispatch(
+            d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0, sizeof(KeyT) * 8, false,
+            nullptr, false)));
 #endif
   } else {
     void *d_temp_storage = nullptr;
 #if THRUST_MAJOR_VERSION >= 2
-    dh::safe_cuda((cub::DispatchRadixSort<detail::kCubSortOrderDescending, KeyT, ValueT, OffsetT>::Dispatch(
-        d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0, sizeof(KeyT) * 8, false,
-        cuctx->Stream())));
+    dh::safe_cuda(
+        (cub::DispatchRadixSort<detail::kCubSortOrderDescending, KeyT, ValueT, OffsetT>::Dispatch(
+            d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0, sizeof(KeyT) * 8, false,
+            cuctx->Stream())));
 #else
-    dh::safe_cuda((cub::DispatchRadixSort<detail::kCubSortOrderDescending, KeyT, ValueT, OffsetT>::Dispatch(
-        d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0, sizeof(KeyT) * 8, false,
-        nullptr, false)));
+    dh::safe_cuda(
+        (cub::DispatchRadixSort<detail::kCubSortOrderDescending, KeyT, ValueT, OffsetT>::Dispatch(
+            d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0, sizeof(KeyT) * 8, false,
+            nullptr, false)));
 #endif
     dh::TemporaryArray<char> storage(bytes);
     d_temp_storage = storage.data().get();
 #if THRUST_MAJOR_VERSION >= 2
-    dh::safe_cuda((cub::DispatchRadixSort<detail::kCubSortOrderDescending, KeyT, ValueT, OffsetT>::Dispatch(
-        d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0, sizeof(KeyT) * 8, false,
-        cuctx->Stream())));
+    dh::safe_cuda(
+        (cub::DispatchRadixSort<detail::kCubSortOrderDescending, KeyT, ValueT, OffsetT>::Dispatch(
+            d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0, sizeof(KeyT) * 8, false,
+            cuctx->Stream())));
 #else
-    dh::safe_cuda((cub::DispatchRadixSort<detail::kCubSortOrderDescending, KeyT, ValueT, OffsetT>::Dispatch(
-        d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0, sizeof(KeyT) * 8, false,
-        nullptr, false)));
+    dh::safe_cuda(
+        (cub::DispatchRadixSort<detail::kCubSortOrderDescending, KeyT, ValueT, OffsetT>::Dispatch(
+            d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0, sizeof(KeyT) * 8, false,
+            nullptr, false)));
 #endif
   }

@@ -330,15 +338,15 @@ void InclusiveSum(Context const *ctx, InputIteratorT d_in, OutputIteratorT d_out
 }

 template <typename... Args>
-void RunLengthEncode(dh::CUDAStreamView stream, Args &&...args) {
+void RunLengthEncode(curt::StreamRef stream, Args &&...args) {
   std::size_t n_bytes = 0;
   dh::safe_cuda(cub::DeviceRunLengthEncode::Encode(nullptr, n_bytes, args..., stream));
   dh::CachingDeviceUVector<char> tmp(n_bytes);
   dh::safe_cuda(cub::DeviceRunLengthEncode::Encode(tmp.data(), n_bytes, args..., stream));
 }

 template <typename... Args>
-void SegmentedSum(dh::CUDAStreamView stream, Args &&...args) {
+void SegmentedSum(curt::StreamRef stream, Args &&...args) {
   std::size_t n_bytes = 0;
   dh::safe_cuda(cub::DeviceSegmentedReduce::Sum(nullptr, n_bytes, args..., stream));
   dh::CachingDeviceUVector<char> tmp(n_bytes);
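
RunLengthEncode and SegmentedSum above wrap CUB's two-pass calling convention: a first call with a null temporary buffer only reports the required scratch size, and a second call with the allocated buffer does the work. A standalone sketch of that pattern (illustrative names, not XGBoost code):

#include <cub/device/device_run_length_encode.cuh>
#include <thrust/device_vector.h>

#include <cstddef>  // for size_t
#include <vector>   // for vector

void EncodeExample(cudaStream_t stream) {
  std::vector<int> h_in{1, 1, 2, 2, 2, 3};
  thrust::device_vector<int> in(h_in.begin(), h_in.end());
  thrust::device_vector<int> unique_out(in.size());
  thrust::device_vector<int> counts_out(in.size());
  thrust::device_vector<int> num_runs(1);

  // Pass 1: a null buffer asks CUB only for the required temporary storage size.
  std::size_t n_bytes = 0;
  cub::DeviceRunLengthEncode::Encode(nullptr, n_bytes, in.data().get(), unique_out.data().get(),
                                     counts_out.data().get(), num_runs.data().get(),
                                     static_cast<int>(in.size()), stream);
  thrust::device_vector<char> tmp(n_bytes);
  // Pass 2: the same call with real scratch space performs the encode.
  cub::DeviceRunLengthEncode::Encode(tmp.data().get(), n_bytes, in.data().get(),
                                     unique_out.data().get(), counts_out.data().get(),
                                     num_runs.data().get(), static_cast<int>(in.size()), stream);
}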
