1515
1616namespace xgboost ::tree::cuda_impl {
1717namespace {
18- __device__ bst_bin_t RevBinIdx (bst_bin_t gidx_end, bst_bin_t bin_idx) {
19- return gidx_end - bin_idx - 1 ;
18+ /* *
19+ * @brief Calculate the gradient index for the reverse pass
20+ *
21+ * @note All inputs are global across features.
22+ */
23+ __device__ bst_bin_t RevBinIdx (bst_bin_t gidx_begin, bst_bin_t gidx_end, bst_bin_t bin_idx) {
24+ return gidx_begin + (gidx_end - bin_idx - 1 );
2025}
2126
2227// Scan the histogram in 2 dim for all nodes
@@ -60,7 +65,7 @@ struct ScanHistogramAgent {
6065 __device__ void Backward (common::Span<GradientPairInt64 const > node_histogram,
6166 common::Span<GradientPairInt64> scan_result, bst_target_t t) {
6267 this ->ScanFeature (node_histogram, scan_result, t,
63- [&](bst_bin_t bin_idx) { return RevBinIdx (gidx_end, bin_idx); });
68+ [&](bst_bin_t bin_idx) { return RevBinIdx (gidx_begin, gidx_end, bin_idx); });
6469 }
6570};
6671} // namespace
@@ -112,7 +117,7 @@ struct EvaluateSplitAgent {
112117 template <std::int32_t d_step>
113118 __device__ void Numerical (MultiEvaluateSplitInputs const &node,
114119 MultiEvaluateSplitSharedInputs const &shared,
115- common::Span<GradientPairInt64 const > f_scan ,
120+ common::Span<GradientPairInt64 const > node_scan ,
116121 MultiSplitCandidate *best_split) {
117122 static_assert (d_step == +1 || d_step == -1 , " Invalid step." );
118123 // Calculate split gain for each bin
@@ -130,7 +135,7 @@ struct EvaluateSplitAgent {
130135 double gain = thread_active ? 0 : kNullGain ;
131136
132137 if (thread_active) {
133- auto scan_bin = f_scan .subspan (bin_idx * n_targets, n_targets);
138+ auto scan_bin = node_scan .subspan (bin_idx * n_targets, n_targets);
134139 for (bst_target_t t = 0 ; t < n_targets; ++t) {
135140 auto pg = roundings[t].ToFloatingPoint (node.parent_sum [t]);
136141 // left
@@ -155,7 +160,7 @@ struct EvaluateSplitAgent {
155160 // Update
156161 bst_bin_t split_gidx = bin_idx;
157162 if (d_step == -1 ) {
158- split_gidx = RevBinIdx (gidx_end, bin_idx);
163+ split_gidx = RevBinIdx (gidx_begin, gidx_end, bin_idx);
159164 }
160165 float min_fvalue = shared.min_values [fidx];
161166 float fvalue;
@@ -168,7 +173,7 @@ struct EvaluateSplitAgent {
168173 fvalue = shared.feature_values [split_gidx - 1 ];
169174 }
170175 }
171- auto scan_bin = f_scan .subspan (bin_idx * n_targets, n_targets);
176+ auto scan_bin = node_scan .subspan (bin_idx * n_targets, n_targets);
172177 // Missing values go to right in the forward pass, go to left in the backward pass.
173178 best_split->Update (gain, d_step == 1 ? kRightDir : kLeftDir , fvalue, fidx, scan_bin, false ,
174179 shared.param , shared.roundings );
@@ -190,80 +195,120 @@ __global__ __launch_bounds__(kBlockThreads) void EvaluateSplitsKernel(
190195 using AgentT = EvaluateSplitAgent<kBlockThreads >;
191196 __shared__ typename AgentT::TempStorage temp_storage;
192197
193- auto fidx = blockIdx .x ;
194- EvaluateSplitAgent<kBlockThreads > agent{&temp_storage, blockIdx .x };
198+ const auto nidx = blockIdx .x / shared.Features ();
199+ bst_feature_t fidx = blockIdx .x % shared.Features ();
200+ AgentT agent{&temp_storage, fidx};
195201
196202 auto n_targets = shared.Targets ();
197203 // The number of bins in a feature
198204 auto f_hist_size =
199205 (shared.feature_segments [fidx + 1 ] - shared.feature_segments [fidx]) * n_targets;
200- // TODO(jiamingy): Support more than a single node
206+
207+ auto candidate_idx = nidx * shared.Features () + fidx;
201208
202209 if (shared.one_pass != MultiEvaluateSplitSharedInputs::kBackward ) {
203- auto forward = bin_scans[0 ].subspan (0 , nodes[0 ].histogram .size ());
204- auto f_scan = forward.subspan (shared.feature_segments [fidx] * n_targets, f_hist_size);
205- agent.template Numerical <+1 >(nodes[0 ], shared, f_scan, &out_candidates[fidx]);
210+ auto forward = bin_scans[nidx].subspan (0 , nodes[nidx].histogram .size ());
211+ agent.template Numerical <+1 >(nodes[nidx], shared, forward, &out_candidates[candidate_idx]);
206212 }
207213 if (shared.one_pass != MultiEvaluateSplitSharedInputs::kForward ) {
208- auto backward = bin_scans[ 0 ]. subspan (nodes[ 0 ]. histogram . size (), nodes[ 0 ]. histogram . size ());
209- auto f_scan = backward .subspan (shared. feature_segments [fidx] * n_targets, f_hist_size );
210- agent.template Numerical <-1 >(nodes[0 ], shared, f_scan , &out_candidates[fidx ]);
214+ auto backward =
215+ bin_scans[nidx] .subspan (nodes[nidx]. histogram . size (), nodes[nidx]. histogram . size () );
216+ agent.template Numerical <-1 >(nodes[nidx ], shared, backward , &out_candidates[candidate_idx ]);
211217 }
212218}
213219
214220[[nodiscard]] MultiExpandEntry MultiHistEvaluator::EvaluateSingleSplit (
215- Context const *ctx, MultiEvaluateSplitInputs input,
216- MultiEvaluateSplitSharedInputs shared_inputs) {
221+ Context const *ctx, MultiEvaluateSplitInputs const &input,
222+ MultiEvaluateSplitSharedInputs const &shared_inputs) {
223+ dh::device_vector<MultiEvaluateSplitInputs> inputs{input};
224+ dh::device_vector<MultiExpandEntry> outputs (1 );
225+
226+ this ->EvaluateSplits (ctx, dh::ToSpan (inputs), shared_inputs, dh::ToSpan (outputs));
227+ return outputs[0 ];
228+ }
229+
230+ void MultiHistEvaluator::EvaluateSplits (Context const *ctx,
231+ common::Span<MultiEvaluateSplitInputs const > d_inputs,
232+ MultiEvaluateSplitSharedInputs const &shared_inputs,
233+ common::Span<MultiExpandEntry> out_splits) {
217234 auto n_targets = shared_inputs.Targets ();
218235 CHECK_GE (n_targets, 2 );
219236 auto n_bins_per_feat_tar = shared_inputs.n_bins_per_feat_tar ;
220237 CHECK_GE (n_bins_per_feat_tar, 1 );
221238 auto n_features = shared_inputs.Features ();
222239 CHECK_GE (n_features, 1 );
223240
224- dh::device_vector<MultiEvaluateSplitInputs> inputs{input};
241+ std::uint32_t n_nodes = d_inputs.size ();
242+ CHECK_EQ (n_nodes, out_splits.size ());
243+
244+ if (n_nodes == 0 ) {
245+ return ;
246+ }
247+
248+ // Calculate total scan buffer size needed for all nodes
249+ auto node_hist_size = n_targets * n_features * n_bins_per_feat_tar;
250+ std::size_t total_hist_size = node_hist_size * n_nodes;
225251
226252 // Scan the histograms. One for forward and the other for backward.
227- this ->scan_buffer_ .resize (input. histogram . size () * 2 );
253+ this ->scan_buffer_ .resize (total_hist_size * 2 );
228254 thrust::fill (ctx->CUDACtx ()->CTP (), this ->scan_buffer_ .begin (), this ->scan_buffer_ .end (),
229255 GradientPairInt64{});
230- dh::device_vector<common::Span<GradientPairInt64>> scans{dh::ToSpan (this ->scan_buffer_ )};
231- std::uint32_t n_nodes = 1 ;
256+
257+ // Create spans for each node's scan results
258+ dh::device_vector<common::Span<GradientPairInt64>> scans (n_nodes);
259+ for (std::size_t nidx_in_set = 0 ; nidx_in_set < n_nodes; ++nidx_in_set) {
260+ scans[nidx_in_set] = dh::ToSpan (this ->scan_buffer_ )
261+ .subspan (nidx_in_set * node_hist_size * 2 , node_hist_size * 2 );
262+ }
263+
264+ // Launch histogram scan kernel
232265 dim3 grid{n_nodes, n_features, n_targets};
233266 std::uint32_t constexpr kBlockThreads = 32 ;
234267 dh::LaunchKernel{grid, kBlockThreads }( // NOLINT
235- ScanHistogramKernel<kBlockThreads >, dh::ToSpan (inputs) , shared_inputs, dh::ToSpan (scans));
268+ ScanHistogramKernel<kBlockThreads >, d_inputs , shared_inputs, dh::ToSpan (scans));
236269
237- dh::device_vector<MultiSplitCandidate> d_splits (n_features);
238- dh::LaunchKernel{n_features, kBlockThreads , 0 , ctx->CUDACtx ()->Stream ()}( // NOLINT
239- EvaluateSplitsKernel<kBlockThreads >, dh::ToSpan (inputs), shared_inputs, dh::ToSpan (scans),
270+ // Launch split evaluation kernel
271+ dh::device_vector<MultiSplitCandidate> d_splits (n_nodes * n_features);
272+ dh::LaunchKernel{n_nodes * n_features, kBlockThreads , 0 , ctx->CUDACtx ()->Stream ()}( // NOLINT
273+ EvaluateSplitsKernel<kBlockThreads >, d_inputs, shared_inputs, dh::ToSpan (scans),
240274 dh::ToSpan (d_splits));
241275
242- auto best_split = thrust::reduce (
243- ctx->CUDACtx ()->CTP (), d_splits.cbegin (), d_splits.cend (), MultiSplitCandidate{},
244- [] XGBOOST_DEVICE (MultiSplitCandidate const &lhs, MultiSplitCandidate const &rhs)
245- -> MultiSplitCandidate { return lhs.loss_chg > rhs.loss_chg ? lhs : rhs; });
276+ // Find best split for each node
277+ this ->weights_ .resize (n_nodes * n_targets * 3 );
278+ auto d_weights = dh::ToSpan (this ->weights_ );
246279
247- if (best_split.node_sum .empty ()) {
248- return {};
249- }
280+ dh::CachingDeviceUVector<float > d_parent_gains (n_nodes);
281+ dh::CachingDeviceUVector<std::int32_t > sum_zeros (n_nodes * 2 );
250282
251- // Calculate leaf weights from gradient sum
252- this ->weights_ .resize (n_targets * 3 );
253- auto d_weights = dh::ToSpan (this ->weights_ );
254- auto base_weight = d_weights.subspan (0 , n_targets);
255- auto left_weight = d_weights.subspan (n_targets, n_targets);
256- auto right_weight = d_weights.subspan (n_targets * 2 , n_targets);
283+ auto s_parent_gains = dh::ToSpan (d_parent_gains);
284+ auto s_sum_zeros = dh::ToSpan (sum_zeros);
285+ auto s_d_splits = dh::ToSpan (d_splits);
257286
258- dh::CachingDeviceUVector<float > d_parent_gain (1 );
259- dh::CachingDeviceUVector<std::int32_t > sum_zero (2 );
287+ // Process results for each node
288+ dh::LaunchN (n_nodes, ctx->CUDACtx ()->Stream (), [=] __device__ (std::size_t nidx_in_set) {
289+ auto input = d_inputs[nidx_in_set];
260290
261- auto s_pg = dh::ToSpan (d_parent_gain);
262- auto s_sum_zero = dh::ToSpan (sum_zero);
291+ // Find best split among all features for this node
292+ MultiSplitCandidate best_split{};
293+ for (bst_feature_t f = 0 ; f < n_features; ++f) {
294+ auto candidate = s_d_splits[nidx_in_set * n_features + f];
295+ if (candidate.loss_chg > best_split.loss_chg ) {
296+ best_split = candidate;
297+ }
298+ }
299+
300+ if (best_split.node_sum .empty ()) {
301+ // Invalid split
302+ out_splits[nidx_in_set] = {};
303+ return ;
304+ }
305+
306+ // Calculate weights for this node
307+ auto base_weight = d_weights.subspan (nidx_in_set * n_targets * 3 , n_targets);
308+ auto left_weight = d_weights.subspan (nidx_in_set * n_targets * 3 + n_targets, n_targets);
309+ auto right_weight = d_weights.subspan (nidx_in_set * n_targets * 3 + n_targets * 2 , n_targets);
263310
264- dh::LaunchN (inputs.size (), ctx->CUDACtx ()->Stream (), [=] __device__ (std::size_t i) {
265311 auto d_roundings = shared_inputs.roundings ;
266- // the data inside the split candidates references the scan result.
267312 auto node_sum = best_split.node_sum ;
268313
269314 float parent_gain = 0 ;
@@ -276,15 +321,15 @@ __global__ __launch_bounds__(kBlockThreads) void EvaluateSplitsKernel(
276321 base_weight[t] = CalcWeight (shared_inputs.param , g.GetGrad (), g.GetHess ());
277322 parent_gain += -base_weight[t] * ThresholdL1 (g.GetGrad (), shared_inputs.param .reg_alpha );
278323 }
279- s_pg[ 0 ] = parent_gain;
324+ s_parent_gains[nidx_in_set ] = parent_gain;
280325
281326 bool l = true , r = true ;
282327 for (bst_target_t t = 0 ; t < n_targets; ++t) {
283328 auto quantizer = d_roundings[t];
284329 auto sibling_sum = input.parent_sum [t] - node_sum[t];
285330
286- l = l && (node_sum[t].GetQuantisedHess () - . 0 == . 0 );
287- r = r && (sibling_sum.GetQuantisedHess () - . 0 == . 0 );
331+ l = l && (node_sum[t].GetQuantisedHess () == 0 );
332+ r = r && (sibling_sum.GetQuantisedHess () == 0 );
288333
289334 if (best_split.dir == kRightDir ) {
290335 // forward pass, node_sum is the left sum
@@ -299,43 +344,72 @@ __global__ __launch_bounds__(kBlockThreads) void EvaluateSplitsKernel(
299344 auto lg = quantizer.ToFloatingPoint (sibling_sum);
300345 left_weight[t] = CalcWeight (shared_inputs.param , lg.GetGrad (), lg.GetHess ());
301346 }
347+ }
348+
349+ s_sum_zeros[nidx_in_set * 2 ] = l;
350+ s_sum_zeros[nidx_in_set * 2 + 1 ] = r;
351+
352+ // Set up the output entry
353+ out_splits[nidx_in_set] = {input.nidx , input.depth , best_split,
354+ base_weight, left_weight, right_weight};
355+ out_splits[nidx_in_set].split .loss_chg -= parent_gain;
302356
303- s_sum_zero[ 0 ] = l;
304- s_sum_zero[ 1 ] = r ;
357+ if (l || r) {
358+ out_splits[nidx_in_set ] = {} ;
305359 }
306360 });
307- // Copy the result back to the host.
308- float parent_gain = 0 ;
309- dh::safe_cuda ( cudaMemcpyAsync (&parent_gain, d_parent_gain. data (), sizeof (parent_gain) ,
310- cudaMemcpyDefault, ctx-> CUDACtx ()-> Stream ()));
311- best_split. loss_chg -= parent_gain ;
312-
313- std::vector<std:: int32_t > h_sum_zero (s_sum_zero. size () );
314- dh::safe_cuda ( cudaMemcpyAsync (h_sum_zero. data (), s_sum_zero. data (), s_sum_zero. size_bytes (),
315- cudaMemcpyDefault, ctx-> CUDACtx ()-> Stream ()) );
316- if (h_sum_zero[ 0 ] || h_sum_zero[ 1 ]) {
317- return {};
318- }
361+ }
362+
363+ void MultiHistEvaluator::ApplyTreeSplit (Context const *ctx, RegTree const *p_tree ,
364+ MultiExpandEntry const &candidate) {
365+ auto n_targets = p_tree-> NumTargets () ;
366+
367+ auto left_child = p_tree-> LeftChild (candidate. nidx );
368+ auto right_child = p_tree-> RightChild (candidate. nidx );
369+ bst_node_t max_node = std::max (left_child, right_child );
370+ this -> AllocNodeSum (max_node, n_targets);
371+
372+ auto parent_sum = this -> GetNodeSum (candidate. nidx , n_targets);
319373
320- MultiExpandEntry entry{input.nidx , input.depth , best_split,
321- base_weight, left_weight, right_weight};
322- return entry;
374+ auto left_sum = this ->GetNodeSum (left_child, n_targets);
375+ auto right_sum = this ->GetNodeSum (right_child, n_targets);
376+
377+ // Calculate node sums
378+ // TODO(jiamingy): We need to batch the targets and nodes
379+ auto best_split = candidate.split ;
380+ auto node_sum = best_split.node_sum ;
381+ dh::LaunchN (1 , ctx->CUDACtx ()->Stream (), [=] XGBOOST_DEVICE (std::size_t ) {
382+ for (bst_target_t t = 0 ; t < n_targets; ++t) {
383+ auto sibling_sum = parent_sum[t] - node_sum[t];
384+ if (best_split.dir == kRightDir ) {
385+ // forward pass, node_sum is the left sum
386+ left_sum[t] = node_sum[t];
387+ right_sum[t] = sibling_sum;
388+ } else {
389+ // backward pass, node_sum is the right sum
390+ right_sum[t] = node_sum[t];
391+ left_sum[t] = sibling_sum;
392+ }
393+ }
394+ });
323395}
324396
325- void DebugPrintHistogram (common::Span<GradientPairInt64 const > node_hist,
326- common::Span<GradientQuantiser const > roundings, bst_target_t n_targets) {
397+ std::ostream &DebugPrintHistogram (std::ostream &os, common::Span<GradientPairInt64 const > node_hist,
398+ common::Span<GradientQuantiser const > roundings,
399+ bst_target_t n_targets) {
327400 std::vector<GradientQuantiser> h_roundings;
328401 thrust::copy (dh::tcbegin (roundings), dh::tcend (roundings), std::back_inserter (h_roundings));
329402 dh::CopyDeviceSpanToVector (&h_roundings, roundings);
330403
331404 std::vector<GradientPairInt64> h_node_hist (node_hist.size ());
332405 dh::CopyDeviceSpanToVector (&h_node_hist, node_hist);
333406 for (bst_target_t t = 0 ; t < n_targets; ++t) {
334- std::cout << " target :" << t << std::endl;
407+ os << " Target :" << t << std::endl;
335408 for (std::size_t i = t; i < h_node_hist.size () / n_targets; i += n_targets) {
336- std::cout << h_roundings[t].ToFloatingPoint (h_node_hist[i]) << " , " ;
409+ os << h_roundings[t].ToFloatingPoint (h_node_hist[i]) << " , " ;
337410 }
338- std::cout << std::endl;
411+ os << std::endl;
339412 }
413+ return os;
340414}
341415} // namespace xgboost::tree::cuda_impl
0 commit comments