11/* *
22 * Copyright 2025, XGBoost contributors
33 */
4- #include < thrust/reduce.h> // for reduce_by_key
4+ #include < thrust/reduce.h> // for reduce_by_key, reduce
55
66#include < cub/block/block_scan.cuh> // for BlockScan
77#include < cub/util_type.cuh> // for KeyValuePair
88#include < cub/warp/warp_reduce.cuh> // for WarpReduce
99#include < vector> // for vector
1010
1111#include " ../../common/cuda_context.cuh"
12+ #include " ../tree_view.h" // for MultiTargetTreeView
1213#include " ../updater_gpu_common.cuh" // for SumCallbackOp
1415#include " multi_evaluate_splits.cuh" // for MultiEvaluateSplitInputs, MultiEvaluateSplitSharedInputs
1415#include " quantiser.cuh" // for GradientQuantiser
@@ -221,7 +222,18 @@ __global__ __launch_bounds__(kBlockThreads) void EvaluateSplitsKernel(
221222 dh::device_vector<MultiEvaluateSplitInputs> inputs{input};
222223 dh::device_vector<MultiExpandEntry> outputs (1 );
223224
224- this ->EvaluateSplits (ctx, dh::ToSpan (inputs), shared_inputs, dh::ToSpan (outputs));
225+ auto d_outputs = dh::ToSpan (outputs);
226+ this ->EvaluateSplits (ctx, dh::ToSpan (inputs), shared_inputs, d_outputs);
227+
228+ auto n_targets = shared_inputs.Targets ();
229+ dh::LaunchN (n_targets, ctx->CUDACtx ()->Stream (), [=] XGBOOST_DEVICE (std::size_t t) {
230+ auto weight = d_outputs[0 ].base_weight ;
231+ if (weight.empty ()) {
232+ return ;
233+ }
234+ weight[t] *= shared_inputs.param .learning_rate ;
235+ });
236+
225237 return outputs[0 ];
226238}
227239
@@ -326,6 +338,7 @@ void MultiHistEvaluator::EvaluateSplits(Context const *ctx,
326338
327339 bool l = true , r = true ;
328340 GradientPairPrecise lg_fst, rg_fst;
341+ auto eta = shared_inputs.param .learning_rate ;
329342 for (bst_target_t t = 0 ; t < n_targets; ++t) {
330343 auto quantizer = d_roundings[t];
331344 auto sibling_sum = input.parent_sum [t] - node_sum[t];
@@ -337,15 +350,15 @@ void MultiHistEvaluator::EvaluateSplits(Context const *ctx,
337350 if (best_split.dir == kRightDir ) {
338351 // forward pass, node_sum is the left sum
339352 lg = quantizer.ToFloatingPoint (node_sum[t]);
340- left_weight[t] = CalcWeight (shared_inputs.param , lg.GetGrad (), lg.GetHess ());
353+ left_weight[t] = CalcWeight (shared_inputs.param , lg.GetGrad (), lg.GetHess ()) * eta ;
341354 rg = quantizer.ToFloatingPoint (sibling_sum);
342- right_weight[t] = CalcWeight (shared_inputs.param , rg.GetGrad (), rg.GetHess ());
355+ right_weight[t] = CalcWeight (shared_inputs.param , rg.GetGrad (), rg.GetHess ()) * eta ;
343356 } else {
344357 // backward pass, node_sum is the right sum
345358 rg = quantizer.ToFloatingPoint (node_sum[t]);
346- right_weight[t] = CalcWeight (shared_inputs.param , rg.GetGrad (), rg.GetHess ());
359+ right_weight[t] = CalcWeight (shared_inputs.param , rg.GetGrad (), rg.GetHess ()) * eta ;
347360 lg = quantizer.ToFloatingPoint (sibling_sum);
348- left_weight[t] = CalcWeight (shared_inputs.param , lg.GetGrad (), lg.GetHess ());
361+ left_weight[t] = CalcWeight (shared_inputs.param , lg.GetGrad (), lg.GetHess ()) * eta ;
349362 }
350363
351364 if (t == 0 ) {
@@ -367,35 +380,50 @@ void MultiHistEvaluator::EvaluateSplits(Context const *ctx,
367380}
368381
369382void MultiHistEvaluator::ApplyTreeSplit (Context const *ctx, RegTree const *p_tree,
370- MultiExpandEntry const &candidate) {
371- auto left_child = p_tree->LeftChild (candidate.nidx );
372- auto right_child = p_tree->RightChild (candidate.nidx );
373- bst_node_t max_node = std::max (left_child, right_child);
374- auto n_targets = candidate.base_weight .size ();
375-
383+ common::Span<MultiExpandEntry const > d_candidates,
384+ bst_target_t n_targets) {
385+ // Assign the node sums here, for the next evaluate split call.
386+ auto mt_tree = MultiTargetTreeView{ctx->Device (), p_tree};
387+ auto max_in_it = dh::MakeIndexTransformIter ([=] __device__ (std::size_t i) -> bst_node_t {
388+ return std::max (mt_tree.LeftChild (d_candidates[i].nidx ),
389+ mt_tree.RightChild (d_candidates[i].nidx ));
390+ });
391+ auto max_node = thrust::reduce (
392+ ctx->CUDACtx ()->CTP (), max_in_it, max_in_it + d_candidates.size (), 0 ,
393+ [=] XGBOOST_DEVICE (bst_node_t l, bst_node_t r) { return cuda::std::max (l, r); });
376394 this ->AllocNodeSum (max_node, n_targets);
377395
378- auto parent_sum = this ->GetNodeSum (candidate.nidx , n_targets);
379- auto left_sum = this ->GetNodeSum (left_child, n_targets);
380- auto right_sum = this ->GetNodeSum (right_child, n_targets);
381-
382- // Calculate node sums
383- // TODO(jiamingy): We need to batch the nodes
384- auto best_split = candidate.split ;
385-
386- auto node_sum = best_split.child_sum ;
387- dh::LaunchN (n_targets, ctx->CUDACtx ()->Stream (), [=] XGBOOST_DEVICE (std::size_t t) {
388- auto sibling_sum = parent_sum[t] - node_sum[t];
389- if (best_split.dir == kRightDir ) {
390- // forward pass, node_sum is the left sum
391- left_sum[t] = node_sum[t];
392- right_sum[t] = sibling_sum;
393- } else {
394- // backward pass, node_sum is the right sum
395- right_sum[t] = node_sum[t];
396- left_sum[t] = sibling_sum;
397- }
398- });
396+ auto node_sums = dh::ToSpan (this ->node_sums_ );
397+
398+ dh::LaunchN (n_targets * d_candidates.size (), ctx->CUDACtx ()->Stream (),
399+ [=] XGBOOST_DEVICE (std::size_t i) {
400+ auto get_node_sum = [&](bst_node_t nidx) {
401+ return GetNodeSumImpl (node_sums, nidx, n_targets);
402+ };
403+ auto nidx_in_set = i / n_targets;
404+ auto t = i % n_targets;
405+
406+ auto const &candidate = d_candidates[nidx_in_set];
407+ auto const &best_split = candidate.split ;
408+
409+ auto parent_sum = get_node_sum (candidate.nidx );
410+ // The child sum is a pointer to the scan buffer in this evaluator. Copy
411+ // the data into the node sum buffer before the next evaluation call.
412+ auto node_sum = best_split.child_sum ;
413+ auto left_sum = get_node_sum (mt_tree.LeftChild (candidate.nidx ));
414+ auto right_sum = get_node_sum (mt_tree.RightChild (candidate.nidx ));
415+
416+ auto sibling_sum = parent_sum[t] - node_sum[t];
417+ if (best_split.dir == kRightDir ) {
418+ // forward pass, node_sum is the left sum
419+ left_sum[t] = node_sum[t];
420+ right_sum[t] = sibling_sum;
421+ } else {
422+ // backward pass, node_sum is the right sum
423+ right_sum[t] = node_sum[t];
424+ left_sum[t] = sibling_sum;
425+ }
426+ });
399427}
400428
401429std::ostream &DebugPrintHistogram (std::ostream &os, common::Span<GradientPairInt64 const > node_hist,
0 commit comments