44#ifndef XGBOOST_COMMON_LINALG_OP_CUH_
55#define XGBOOST_COMMON_LINALG_OP_CUH_
66
7- #include < cstdint> // for int32_t
8- #include < cstdlib> // for size_t
9- #include < tuple> // for apply
7+ #include < thrust/iterator/counting_iterator.h> // for counting_iterator
8+ #include < thrust/iterator/zip_iterator.h> // for make_zip_iterator
9+ #include < thrust/transform.h> // for transform
10+
11+ #include < cstdint> // for int32_t
12+ #include < cstdlib> // for size_t
13+ #include < cuda/std/iterator> // for iterator_traits
14+ #include < cuda/std/tuple> // for get
15+ #include < tuple> // for apply
1016
1117#include " cuda_context.cuh"
1218#include " device_helpers.cuh" // for LaunchN
13- #include " linalg_op .h"
14- #include " xgboost/context.h" // for Context
15- #include " xgboost/linalg.h" // for TensorView
19+ #include " type .h" // for GetValueT
20+ #include " xgboost/context.h" // for Context
21+ #include " xgboost/linalg.h" // for TensorView
1622
1723namespace xgboost ::linalg {
1824namespace cuda_impl {
@@ -40,17 +46,22 @@ struct ElementWiseImpl<T, 1> {
4046template <typename T, std::int32_t D, typename Fn>
4147void ElementWiseKernel (TensorView<T, D> t, Fn&& fn, cudaStream_t s = nullptr ) {
4248 dh::safe_cuda (cudaSetDevice (t.Device ().ordinal ));
43- cuda_impl:: ElementWiseImpl<T, D>{}(t, fn, s);
49+ ElementWiseImpl<T, D>{}(t, fn, s);
4450}
4551
46- void VecScaMul (Context const * ctx, linalg::VectorView<float > x, double mul);
47- } // namespace cuda_impl
48-
49- template <typename T, int32_t D, typename Fn>
50- void ElementWiseTransformDevice (TensorView<T, D> t, Fn&& fn, cudaStream_t s = nullptr ) {
52+ template <typename T, std::int32_t D, typename Fn>
53+ void TransformIdxKernel (Context const * ctx, TensorView<T, D> t, Fn&& fn) {
54+ dh::safe_cuda (cudaSetDevice (t.Device ().ordinal ));
55+ auto s = ctx->CUDACtx ()->Stream ();
5156 if (t.Contiguous ()) {
5257 auto ptr = t.Values ().data ();
53- dh::LaunchN (t.Size (), s, [=] __device__ (size_t i) { ptr[i] = fn (i, ptr[i]); });
58+ auto it =
59+ thrust::make_zip_iterator (thrust::make_counting_iterator (static_cast <std::size_t >(0 )), ptr);
60+ using Tuple = typename cuda::std::iterator_traits<common::GetValueT<decltype (it)>>::value_type;
61+ thrust::transform (ctx->CUDACtx ()->CTP (), it, it + t.Size (), ptr,
62+ [=] XGBOOST_DEVICE (Tuple const & tup) {
63+ return fn (cuda::std::get<0 >(tup), cuda::std::get<1 >(tup));
64+ });
5465 } else {
5566 dh::LaunchN (t.Size (), s, [=] __device__ (size_t i) mutable {
5667 T& v = std::apply (t, UnravelIndex (i, t.Shape ()));
@@ -59,44 +70,53 @@ void ElementWiseTransformDevice(TensorView<T, D> t, Fn&& fn, cudaStream_t s = nu
5970 }
6071}
6172
62- template <typename T, int32_t D, typename Fn>
63- void ElementWiseKernel (Context const * ctx, TensorView<T, D> t, Fn&& fn) {
64- ctx->IsCUDA () ? cuda_impl::ElementWiseKernel (t, fn)
65- : ElementWiseKernelHost (t, ctx->Threads (), fn);
73+ template <typename T, std::int32_t D, typename Fn>
74+ void TransformKernel (Context const * ctx, TensorView<T, D> t, Fn&& fn) {
75+ dh::safe_cuda (cudaSetDevice (t.Device ().ordinal ));
76+ auto s = ctx->CUDACtx ()->Stream ();
77+ if (t.Contiguous ()) {
78+ auto ptr = t.Values ().data ();
79+ thrust::transform (ctx->CUDACtx ()->CTP (), ptr, ptr + t.Size (), ptr,
80+ [=] XGBOOST_DEVICE (T const & v) { return fn (v); });
81+ } else {
82+ dh::LaunchN (t.Size (), s, [=] __device__ (size_t i) mutable {
83+ T& v = std::apply (t, UnravelIndex (i, t.Shape ()));
84+ v = fn (v);
85+ });
86+ }
6687}
88+ } // namespace cuda_impl
6789
6890namespace detail {
69- template <typename T, std::int32_t kDim >
91+ template <typename T, std::int32_t D >
7092struct IterOp {
71- TensorView<T, kDim > v;
72- XGBOOST_DEVICE T& operator ()(std::size_t i) {
73- return std::apply (v, UnravelIndex (i, v.Shape ()));
74- }
93+ TensorView<T, D> v;
94+ XGBOOST_DEVICE T& operator ()(std::size_t i) { return std::apply (v, UnravelIndex (i, v.Shape ())); }
7595};
7696} // namespace detail
7797
7898// naming: thrust begin
7999// returns a thrust iterator for a tensor view.
80- template <typename T, std::int32_t kDim >
81- auto tcbegin (TensorView<T, kDim > v) { // NOLINT
100+ template <typename T, std::int32_t D >
101+ auto tcbegin (TensorView<T, D > v) { // NOLINT
82102 return thrust::make_transform_iterator (
83103 thrust::make_counting_iterator (0ul ),
84- detail::IterOp<std::add_const_t <std::remove_const_t <T>>, kDim >{v});
104+ detail::IterOp<std::add_const_t <std::remove_const_t <T>>, D >{v});
85105}
86106
87- template <typename T, std::int32_t kDim >
88- auto tcend (TensorView<T, kDim > v) { // NOLINT
107+ template <typename T, std::int32_t D >
108+ auto tcend (TensorView<T, D > v) { // NOLINT
89109 return tcbegin (v) + v.Size ();
90110}
91111
92- template <typename T, std::int32_t kDim >
93- auto tbegin (TensorView<T, kDim > v) { // NOLINT
112+ template <typename T, std::int32_t D >
113+ auto tbegin (TensorView<T, D > v) { // NOLINT
94114 return thrust::make_transform_iterator (thrust::make_counting_iterator (0ul ),
95- detail::IterOp<std::remove_const_t <T>, kDim >{v});
115+ detail::IterOp<std::remove_const_t <T>, D >{v});
96116}
97117
98- template <typename T, std::int32_t kDim >
99- auto tend (TensorView<T, kDim > v) { // NOLINT
118+ template <typename T, std::int32_t D >
119+ auto tend (TensorView<T, D > v) { // NOLINT
100120 return tbegin (v) + v.Size ();
101121}
102122} // namespace xgboost::linalg
0 commit comments