Skip to content

Commit fd4f50c

Browse files
committed
Don't set device.
1 parent de818bd commit fd4f50c

File tree

2 files changed

+4
-2
lines changed

2 files changed

+4
-2
lines changed

src/common/linalg_op.cuh

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,6 @@ struct ElementWiseImpl<T, 1> {
4545

4646
template <typename T, std::int32_t D, typename Fn>
4747
void ElementWiseKernel(TensorView<T, D> t, Fn&& fn, cudaStream_t s = nullptr) {
48-
dh::safe_cuda(cudaSetDevice(t.Device().ordinal));
4948
ElementWiseImpl<T, D>{}(t, fn, s);
5049
}
5150

src/common/linalg_op.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ auto end(TensorView<T, D>& v) { // NOLINT
126126

127127
namespace detail {
128128
using SysTagImpl = std::int32_t;
129-
129+
// Magic for complying with the ODR.
130130
#if defined(__CUDACC__)
131131
constexpr SysTagImpl SysTag() { return 0; }
132132
#elif defined(XGBOOST_USE_SYCL)
@@ -163,6 +163,7 @@ void ElementWiseKernel(Context const* ctx, TensorView<T, D> t, Fn&& fn) {
163163
#else
164164
template <typename T, std::int32_t D, typename Fn, auto _tag = detail::SysTag()>
165165
void ElementWiseKernel(Context const* ctx, TensorView<T, D> t, Fn&& fn) {
166+
CHECK(ctx->IsCPU());
166167
ctx->DispatchDevice([&] { cpu_impl::ElementWiseKernel(t, ctx->Threads(), std::forward<Fn>(fn)); },
167168
[&] { LOG(FATAL) << "Invalid TU"; });
168169
}
@@ -199,6 +200,7 @@ void TransformIdxKernel(Context const* ctx, TensorView<T, D> t, Fn&& fn) {
199200
#else
200201
template <typename T, std::int32_t D, typename Fn, auto _tag = detail::SysTag()>
201202
void TransformIdxKernel(Context const* ctx, TensorView<T, D> t, Fn&& fn) {
203+
CHECK(ctx->IsCPU());
202204
ctx->DispatchDevice(
203205
[&] { cpu_impl::TransformIdxKernel(t, ctx->Threads(), std::forward<Fn>(fn)); },
204206
[&] { LOG(FATAL) << "Invalid TU."; });
@@ -229,6 +231,7 @@ void TransformKernel(Context const* ctx, TensorView<T, D> t, Fn&& fn) {
229231
#else
230232
template <typename T, std::int32_t D, typename Fn, auto _tag = detail::SysTag()>
231233
void TransformKernel(Context const* ctx, TensorView<T, D> t, Fn&& fn) {
234+
CHECK(ctx->IsCPU());
232235
ctx->DispatchDevice([&] { cpu_impl::TransformKernel(t, ctx->Threads(), std::forward<Fn>(fn)); },
233236
[&] { LOG(FATAL) << "Invalid TU."; });
234237
}

0 commit comments

Comments
 (0)