Skip to content

Commit e739915

Browse files
authored
Cleanup linalg kernels. (#11802)
- Split up two different transform kernels.
- Handle the dispatching in the header instead of the client code.
- Use thrust transform for transform kernels.
1 parent 73261fe commit e739915

23 files changed

+327 -230 lines changed

include/xgboost/linalg.h

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -273,7 +273,7 @@ enum Order : std::uint8_t {
273273
* some functions expect data types that can be used in everywhere (update prediction
274274
* cache for example).
275275
*/
276-
template <typename T, int32_t kDim>
276+
template <typename T, std::int32_t kDim>
277277
class TensorView {
278278
public:
279279
using ShapeT = std::size_t[kDim];
@@ -300,7 +300,7 @@ class TensorView {
300300
}
301301
}
302302

303-
template <size_t old_dim, size_t new_dim, int32_t D, typename I>
303+
template <size_t old_dim, size_t new_dim, std::int32_t D, typename I>
304304
LINALG_HD size_t MakeSliceDim(std::size_t new_shape[D], std::size_t new_stride[D],
305305
detail::RangeTag<I> &&range) const {
306306
static_assert(new_dim < D);

plugin/sycl/common/linalg_op.cc

Lines changed: 0 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,6 @@
1212
#include <sycl/sycl.hpp>
1313

1414
namespace xgboost::sycl::linalg {
15-
1615
void SmallHistogram(Context const* ctx, xgboost::linalg::MatrixView<float const> indices,
1716
xgboost::common::OptionalWeights const& weights,
1817
xgboost::linalg::VectorView<float> bins) {
@@ -30,23 +29,4 @@ void SmallHistogram(Context const* ctx, xgboost::linalg::MatrixView<float const>
3029
});
3130
}).wait();
3231
}
33-
34-
void VecScaMul(Context const* ctx, xgboost::linalg::VectorView<float> x, double mul) {
35-
sycl::DeviceManager device_manager;
36-
auto* qu = device_manager.GetQueue(ctx->Device());
37-
38-
qu->submit([&](::sycl::handler& cgh) {
39-
cgh.parallel_for<>(::sycl::range<1>(x.Size()),
40-
[=](::sycl::id<1> pid) {
41-
const size_t i = pid[0];
42-
const_cast<float&>(x(i)) *= mul;
43-
});
44-
}).wait();
45-
}
4632
} // namespace xgboost::sycl::linalg
47-
48-
namespace xgboost::linalg::sycl_impl {
49-
void VecScaMul(Context const* ctx, xgboost::linalg::VectorView<float> x, double mul) {
50-
xgboost::sycl::linalg::VecScaMul(ctx, x, mul);
51-
}
52-
} // namespace xgboost::linalg::sycl_impl

plugin/sycl/common/linalg_op.h

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,6 @@
88
#include <vector>
99
#include <utility>
1010

11-
#include "../../../src/common/linalg_op.h"
12-
1311
#include "../data.h"
1412
#include "../device_manager.h"
1513

@@ -99,17 +97,5 @@ bool Validate(DeviceOrd device, TensorView<T, D> t, Fn&& fn) {
9997

10098
} // namespace linalg
10199
} // namespace sycl
102-
103-
namespace linalg {
104-
template <typename T, int32_t D, typename Fn>
105-
void ElementWiseKernel(Context const* ctx, TensorView<T, D> t, Fn&& fn) {
106-
if (ctx->IsSycl()) {
107-
sycl::linalg::ElementWiseKernel(t, fn);
108-
} else {
109-
ElementWiseKernelHost(t, ctx->Threads(), fn);
110-
}
111-
}
112-
113-
} // namespace linalg
114100
} // namespace xgboost
115101
#endif // PLUGIN_SYCL_COMMON_LINALG_OP_H_

plugin/sycl/tree/hist_updater.cc

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "../../src/tree/common_row_partitioner.h"
1313

1414
#include "../common/hist_util.h"
15+
#include "xgboost/linalg.h"
1516
#include "../../src/collective/allreduce.h"
1617

1718
namespace xgboost {
@@ -34,8 +35,8 @@ void HistUpdater<GradientSumT>::ReduceHists(const std::vector<int>& sync_ids,
3435
qu_->memcpy(reduce_buffer_.data() + i * nbins, psrc, nbins*sizeof(GradientPairT)).wait();
3536
}
3637

37-
auto buffer_vec = linalg::MakeVec(reinterpret_cast<GradientSumT*>(reduce_buffer_.data()),
38-
2 * nbins * sync_ids.size());
38+
auto buffer_vec = ::xgboost::linalg::MakeVec(
39+
reinterpret_cast<GradientSumT*>(reduce_buffer_.data()), 2 * nbins * sync_ids.size());
3940
auto rc = collective::Allreduce(ctx_, buffer_vec, collective::Op::kSum);
4041
SafeColl(rc);
4142

@@ -361,10 +362,9 @@ void HistUpdater<GradientSumT>::Update(
361362
builder_monitor_.Stop("Update");
362363
}
363364

364-
template<typename GradientSumT>
365+
template <typename GradientSumT>
365366
bool HistUpdater<GradientSumT>::UpdatePredictionCache(
366-
const DMatrix* data,
367-
linalg::MatrixView<float> out_preds) {
367+
const DMatrix* data, ::xgboost::linalg::MatrixView<float> out_preds) {
368368
CHECK(out_preds.Device().IsSycl());
369369
// p_last_fmat_ is a valid pointer as long as UpdatePredictionCache() is called in
370370
// conjunction with Update().
@@ -723,8 +723,8 @@ void HistUpdater<GradientSumT>::InitNewNode(int nid,
723723
}).wait_and_throw();
724724
}
725725
auto rc = collective::Allreduce(
726-
ctx_, linalg::MakeVec(reinterpret_cast<GradientSumT*>(&grad_stat), 2),
727-
collective::Op::kSum);
726+
ctx_, ::xgboost::linalg::MakeVec(reinterpret_cast<GradientSumT*>(&grad_stat), 2),
727+
collective::Op::kSum);
728728
SafeColl(rc);
729729
snode_host_[nid].stats = grad_stat;
730730
} else {

plugin/sycl/tree/hist_updater.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*!
2-
* Copyright 2017-2024 by Contributors
2+
* Copyright 2017-2025, XGBoost Contributors
33
* \file hist_updater.h
44
*/
55
#ifndef PLUGIN_SYCL_TREE_HIST_UPDATER_H_
@@ -8,6 +8,7 @@
88
#pragma GCC diagnostic push
99
#pragma GCC diagnostic ignored "-Wtautological-constant-compare"
1010
#pragma GCC diagnostic ignored "-W#pragma-messages"
11+
#include <xgboost/linalg.h> // for MatrixView
1112
#include <xgboost/tree_updater.h>
1213
#pragma GCC diagnostic pop
1314

@@ -80,8 +81,7 @@ class HistUpdater {
8081
xgboost::common::Span<HostDeviceVector<bst_node_t>> out_position,
8182
RegTree *p_tree);
8283

83-
bool UpdatePredictionCache(const DMatrix* data,
84-
linalg::MatrixView<float> p_out_preds);
84+
bool UpdatePredictionCache(const DMatrix* data, ::xgboost::linalg::MatrixView<float> p_out_preds);
8585

8686
void SetHistSynchronizer(HistSynchronizer<GradientSumT>* sync);
8787
void SetHistRowsAdder(HistRowsAdder<GradientSumT>* adder);

plugin/sycl/tree/updater_quantile_hist.cc

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -60,14 +60,12 @@ void QuantileHistMaker::SetPimpl(std::unique_ptr<HistUpdater<GradientSumT>>* pim
6060
}
6161
}
6262

63-
template<typename GradientSumT>
64-
void QuantileHistMaker::CallUpdate(
65-
const std::unique_ptr<HistUpdater<GradientSumT>>& pimpl,
66-
xgboost::tree::TrainParam const *param,
67-
linalg::Matrix<GradientPair> *gpair,
68-
DMatrix *dmat,
69-
xgboost::common::Span<HostDeviceVector<bst_node_t>> out_position,
70-
const std::vector<RegTree *> &trees) {
63+
template <typename GradientSumT>
64+
void QuantileHistMaker::CallUpdate(const std::unique_ptr<HistUpdater<GradientSumT>> &pimpl,
65+
xgboost::tree::TrainParam const *param,
66+
::xgboost::linalg::Matrix<GradientPair> *gpair, DMatrix *dmat,
67+
xgboost::common::Span<HostDeviceVector<bst_node_t>> out_position,
68+
const std::vector<RegTree *> &trees) {
7169
for (auto tree : trees) {
7270
pimpl->Update(param, gmat_, *(gpair->Data()), dmat, out_position, tree);
7371
}
@@ -107,8 +105,8 @@ void QuantileHistMaker::Update(xgboost::tree::TrainParam const *param, GradientC
107105
p_last_dmat_ = dmat;
108106
}
109107

110-
bool QuantileHistMaker::UpdatePredictionCache(const DMatrix* data,
111-
linalg::MatrixView<float> out_preds) {
108+
bool QuantileHistMaker::UpdatePredictionCache(const DMatrix *data,
109+
::xgboost::linalg::MatrixView<float> out_preds) {
112110
if (param_.subsample < 1.0f) return false;
113111

114112
if (hist_precision_ == HistPrecision::fp32) {

plugin/sycl/tree/updater_quantile_hist.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ class QuantileHistMaker: public TreeUpdater {
5353
const std::vector<RegTree*>& trees) override;
5454

5555
bool UpdatePredictionCache(const DMatrix* data,
56-
linalg::MatrixView<float> out_preds) override;
56+
::xgboost::linalg::MatrixView<float> out_preds) override;
5757

5858
void LoadConfig(Json const& in) override {
5959
auto const& config = get<Object const>(in);
@@ -90,7 +90,7 @@ class QuantileHistMaker: public TreeUpdater {
9090
template<typename GradientSumT>
9191
void CallUpdate(const std::unique_ptr<HistUpdater<GradientSumT>>& builder,
9292
xgboost::tree::TrainParam const *param,
93-
linalg::Matrix<GradientPair> *gpair,
93+
::xgboost::linalg::Matrix<GradientPair> *gpair,
9494
DMatrix *dmat,
9595
xgboost::common::Span<HostDeviceVector<bst_node_t>> out_position,
9696
const std::vector<RegTree *> &trees);

src/common/linalg_op.cu

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
11
/**
22
* Copyright 2025, XGBoost Contributors
33
*/
4-
#include <thrust/for_each.h> // for for_each_n
5-
#include <thrust/iterator/counting_iterator.h> // for make_counting_iterator
6-
#include <thrust/scan.h> // for inclusive_scan
4+
#include <thrust/scan.h> // for inclusive_scan
75

86
#include <cstddef> // for size_t
97

@@ -15,11 +13,6 @@
1513
#include "xgboost/linalg.h" // for VectorView
1614

1715
namespace xgboost::linalg::cuda_impl {
18-
void VecScaMul(Context const* ctx, linalg::VectorView<float> x, double mul) {
19-
thrust::for_each_n(ctx->CUDACtx()->CTP(), thrust::make_counting_iterator(0ul), x.Size(),
20-
[=] XGBOOST_DEVICE(std::size_t i) mutable { x(i) = x(i) * mul; });
21-
}
22-
2316
void SmallHistogram(Context const* ctx, linalg::MatrixView<float const> indices,
2417
common::OptionalWeights const& d_weights, linalg::VectorView<float> bins) {
2518
auto n_bins = bins.Size();

src/common/linalg_op.cuh

Lines changed: 52 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,21 @@
44
#ifndef XGBOOST_COMMON_LINALG_OP_CUH_
55
#define XGBOOST_COMMON_LINALG_OP_CUH_
66

7-
#include <cstdint> // for int32_t
8-
#include <cstdlib> // for size_t
9-
#include <tuple> // for apply
7+
#include <thrust/iterator/counting_iterator.h> // for counting_iterator
8+
#include <thrust/iterator/zip_iterator.h> // for make_zip_iterator
9+
#include <thrust/transform.h> // for transform
10+
11+
#include <cstdint> // for int32_t
12+
#include <cstdlib> // for size_t
13+
#include <cuda/std/iterator> // for iterator_traits
14+
#include <cuda/std/tuple> // for get
15+
#include <tuple> // for apply
1016

1117
#include "cuda_context.cuh"
1218
#include "device_helpers.cuh" // for LaunchN
13-
#include "linalg_op.h"
14-
#include "xgboost/context.h" // for Context
15-
#include "xgboost/linalg.h" // for TensorView
19+
#include "type.h" // for GetValueT
20+
#include "xgboost/context.h" // for Context
21+
#include "xgboost/linalg.h" // for TensorView
1622

1723
namespace xgboost::linalg {
1824
namespace cuda_impl {
@@ -40,17 +46,22 @@ struct ElementWiseImpl<T, 1> {
4046
template <typename T, std::int32_t D, typename Fn>
4147
void ElementWiseKernel(TensorView<T, D> t, Fn&& fn, cudaStream_t s = nullptr) {
4248
dh::safe_cuda(cudaSetDevice(t.Device().ordinal));
43-
cuda_impl::ElementWiseImpl<T, D>{}(t, fn, s);
49+
ElementWiseImpl<T, D>{}(t, fn, s);
4450
}
4551

46-
void VecScaMul(Context const* ctx, linalg::VectorView<float> x, double mul);
47-
} // namespace cuda_impl
48-
49-
template <typename T, int32_t D, typename Fn>
50-
void ElementWiseTransformDevice(TensorView<T, D> t, Fn&& fn, cudaStream_t s = nullptr) {
52+
template <typename T, std::int32_t D, typename Fn>
53+
void TransformIdxKernel(Context const* ctx, TensorView<T, D> t, Fn&& fn) {
54+
dh::safe_cuda(cudaSetDevice(t.Device().ordinal));
55+
auto s = ctx->CUDACtx()->Stream();
5156
if (t.Contiguous()) {
5257
auto ptr = t.Values().data();
53-
dh::LaunchN(t.Size(), s, [=] __device__(size_t i) { ptr[i] = fn(i, ptr[i]); });
58+
auto it =
59+
thrust::make_zip_iterator(thrust::make_counting_iterator(static_cast<std::size_t>(0)), ptr);
60+
using Tuple = typename cuda::std::iterator_traits<common::GetValueT<decltype(it)>>::value_type;
61+
thrust::transform(ctx->CUDACtx()->CTP(), it, it + t.Size(), ptr,
62+
[=] XGBOOST_DEVICE(Tuple const& tup) {
63+
return fn(cuda::std::get<0>(tup), cuda::std::get<1>(tup));
64+
});
5465
} else {
5566
dh::LaunchN(t.Size(), s, [=] __device__(size_t i) mutable {
5667
T& v = std::apply(t, UnravelIndex(i, t.Shape()));
@@ -59,44 +70,53 @@ void ElementWiseTransformDevice(TensorView<T, D> t, Fn&& fn, cudaStream_t s = nu
5970
}
6071
}
6172

62-
template <typename T, int32_t D, typename Fn>
63-
void ElementWiseKernel(Context const* ctx, TensorView<T, D> t, Fn&& fn) {
64-
ctx->IsCUDA() ? cuda_impl::ElementWiseKernel(t, fn)
65-
: ElementWiseKernelHost(t, ctx->Threads(), fn);
73+
template <typename T, std::int32_t D, typename Fn>
74+
void TransformKernel(Context const* ctx, TensorView<T, D> t, Fn&& fn) {
75+
dh::safe_cuda(cudaSetDevice(t.Device().ordinal));
76+
auto s = ctx->CUDACtx()->Stream();
77+
if (t.Contiguous()) {
78+
auto ptr = t.Values().data();
79+
thrust::transform(ctx->CUDACtx()->CTP(), ptr, ptr + t.Size(), ptr,
80+
[=] XGBOOST_DEVICE(T const& v) { return fn(v); });
81+
} else {
82+
dh::LaunchN(t.Size(), s, [=] __device__(size_t i) mutable {
83+
T& v = std::apply(t, UnravelIndex(i, t.Shape()));
84+
v = fn(v);
85+
});
86+
}
6687
}
88+
} // namespace cuda_impl
6789

6890
namespace detail {
69-
template <typename T, std::int32_t kDim>
91+
template <typename T, std::int32_t D>
7092
struct IterOp {
71-
TensorView<T, kDim> v;
72-
XGBOOST_DEVICE T& operator()(std::size_t i) {
73-
return std::apply(v, UnravelIndex(i, v.Shape()));
74-
}
93+
TensorView<T, D> v;
94+
XGBOOST_DEVICE T& operator()(std::size_t i) { return std::apply(v, UnravelIndex(i, v.Shape())); }
7595
};
7696
} // namespace detail
7797

7898
// naming: thrust begin
7999
// returns a thrust iterator for a tensor view.
80-
template <typename T, std::int32_t kDim>
81-
auto tcbegin(TensorView<T, kDim> v) { // NOLINT
100+
template <typename T, std::int32_t D>
101+
auto tcbegin(TensorView<T, D> v) { // NOLINT
82102
return thrust::make_transform_iterator(
83103
thrust::make_counting_iterator(0ul),
84-
detail::IterOp<std::add_const_t<std::remove_const_t<T>>, kDim>{v});
104+
detail::IterOp<std::add_const_t<std::remove_const_t<T>>, D>{v});
85105
}
86106

87-
template <typename T, std::int32_t kDim>
88-
auto tcend(TensorView<T, kDim> v) { // NOLINT
107+
template <typename T, std::int32_t D>
108+
auto tcend(TensorView<T, D> v) { // NOLINT
89109
return tcbegin(v) + v.Size();
90110
}
91111

92-
template <typename T, std::int32_t kDim>
93-
auto tbegin(TensorView<T, kDim> v) { // NOLINT
112+
template <typename T, std::int32_t D>
113+
auto tbegin(TensorView<T, D> v) { // NOLINT
94114
return thrust::make_transform_iterator(thrust::make_counting_iterator(0ul),
95-
detail::IterOp<std::remove_const_t<T>, kDim>{v});
115+
detail::IterOp<std::remove_const_t<T>, D>{v});
96116
}
97117

98-
template <typename T, std::int32_t kDim>
99-
auto tend(TensorView<T, kDim> v) { // NOLINT
118+
template <typename T, std::int32_t D>
119+
auto tend(TensorView<T, D> v) { // NOLINT
100120
return tbegin(v) + v.Size();
101121
}
102122
} // namespace xgboost::linalg

0 commit comments

Comments
 (0)