@@ -352,9 +352,8 @@ __device__ void SetCategoricalSplit(const EvaluateSplitSharedInputs &shared_inpu
352352}
353353
354354void GPUHistEvaluator::LaunchEvaluateSplits (
355- bst_feature_t max_active_features,
356- common::Span<const EvaluateSplitInputs> d_inputs,
357- EvaluateSplitSharedInputs shared_inputs,
355+ Context const *ctx, bst_feature_t max_active_features,
356+ common::Span<const EvaluateSplitInputs> d_inputs, EvaluateSplitSharedInputs shared_inputs,
358357 TreeEvaluator::SplitEvaluator<GPUTrainingParam> evaluator,
359358 common::Span<DeviceSplitCandidate> out_splits) {
360359 if (need_sort_histogram_) {
@@ -367,28 +366,25 @@ void GPUHistEvaluator::LaunchEvaluateSplits(
367366
368367 // One block for each feature
369368 uint32_t constexpr kBlockThreads = 32 ;
370- dh::LaunchKernel {static_cast <uint32_t >(combined_num_features), kBlockThreads ,
371- 0 }(
372- EvaluateSplitsKernel<kBlockThreads >, max_active_features, d_inputs,
373- shared_inputs,
374- this ->SortedIdx (d_inputs.size (), shared_inputs.feature_values .size ()),
375- evaluator, dh::ToSpan (feature_best_splits));
369+ dh::LaunchKernel{static_cast <uint32_t >(combined_num_features), kBlockThreads , 0 , // NOLINT
370+ ctx->CUDACtx ()->Stream ()}(
371+ EvaluateSplitsKernel<kBlockThreads >, max_active_features, d_inputs, shared_inputs,
372+ this ->SortedIdx (d_inputs.size (), shared_inputs.feature_values .size ()), evaluator,
373+ dh::ToSpan (feature_best_splits));
376374
377375 // Reduce to get best candidate for left and right child over all features
378- auto reduce_offset =
379- dh::MakeTransformIterator<size_t >(thrust::make_counting_iterator (0llu),
380- [=] __device__ (size_t idx) -> size_t {
381- return idx * max_active_features;
382- });
376+ auto reduce_offset = dh::MakeTransformIterator<size_t >(
377+ thrust::make_counting_iterator (0llu),
378+ [=] __device__ (size_t idx) -> size_t { return idx * max_active_features; });
383379 size_t temp_storage_bytes = 0 ;
384380 auto num_segments = out_splits.size ();
385- cub::DeviceSegmentedReduce::Sum (nullptr , temp_storage_bytes, feature_best_splits. data (),
386- out_splits.data (), num_segments, reduce_offset ,
387- reduce_offset + 1 );
381+ dh::safe_cuda ( cub::DeviceSegmentedReduce::Sum (
382+ nullptr , temp_storage_bytes, feature_best_splits. data (), out_splits.data (), num_segments,
383+ reduce_offset, reduce_offset + 1 , ctx-> CUDACtx ()-> Stream ()) );
388384 dh::TemporaryArray<int8_t > temp (temp_storage_bytes);
389- cub::DeviceSegmentedReduce::Sum (temp. data (). get (), temp_storage_bytes, feature_best_splits. data (),
390- out_splits .data (), num_segments, reduce_offset ,
391- reduce_offset + 1 );
385+ dh::safe_cuda ( cub::DeviceSegmentedReduce::Sum (
386+ temp .data (). get (), temp_storage_bytes, feature_best_splits. data (), out_splits. data () ,
387+ num_segments, reduce_offset, reduce_offset + 1 , ctx-> CUDACtx ()-> Stream ()) );
392388}
393389
394390void GPUHistEvaluator::CopyToHost (const std::vector<bst_node_t > &nidx) {
@@ -414,8 +410,8 @@ void GPUHistEvaluator::EvaluateSplits(Context const *ctx, const std::vector<bst_
414410
415411 dh::TemporaryArray<DeviceSplitCandidate> splits_out_storage (d_inputs.size ());
416412 auto out_splits = dh::ToSpan (splits_out_storage);
417- this ->LaunchEvaluateSplits (max_active_features, d_inputs, shared_inputs,
418- evaluator, out_splits);
413+ this ->LaunchEvaluateSplits (ctx, max_active_features, d_inputs, shared_inputs, evaluator ,
414+ out_splits);
419415
420416 if (is_column_split_) {
421417 // With column-wise data split, we gather the split candidates from all the workers and find the
@@ -427,7 +423,7 @@ void GPUHistEvaluator::EvaluateSplits(Context const *ctx, const std::vector<bst_
427423 all_candidates.subspan (collective::GetRank () * out_splits.size (), out_splits.size ());
428424 dh::safe_cuda (cudaMemcpyAsync (current_rank.data (), out_splits.data (),
429425 out_splits.size () * sizeof (DeviceSplitCandidate),
430- cudaMemcpyDeviceToDevice));
426+ cudaMemcpyDeviceToDevice, ctx-> CUDACtx ()-> Stream () ));
431427 auto rc = collective::Allgather (
432428 ctx, linalg::MakeVec (all_candidates.data (), all_candidates.size (), ctx->Device ()));
433429 collective::SafeColl (rc);
0 commit comments