@@ -202,10 +202,6 @@ __global__ __launch_bounds__(kBlockThreads) void EvaluateSplitsKernel(
202202 AgentT agent{&temp_storage, fidx};
203203
204204 auto n_targets = shared.Targets ();
205- // The number of bins in a feature
206- auto f_hist_size =
207- (shared.feature_segments [fidx + 1 ] - shared.feature_segments [fidx]) * n_targets;
208-
209205 auto candidate_idx = nidx * shared.Features () + fidx;
210206
211207 if (shared.one_pass != MultiEvaluateSplitSharedInputs::kBackward ) {
@@ -256,11 +252,12 @@ void MultiHistEvaluator::EvaluateSplits(Context const *ctx,
256252 GradientPairInt64{});
257253
258254 // Create spans for each node's scan results
259- dh::device_vector <common::Span<GradientPairInt64>> scans (n_nodes);
255+ std::vector <common::Span<GradientPairInt64>> h_scans (n_nodes);
260256 for (std::size_t nidx_in_set = 0 ; nidx_in_set < n_nodes; ++nidx_in_set) {
261- scans [nidx_in_set] = dh::ToSpan (this ->scan_buffer_ )
262- .subspan (nidx_in_set * node_hist_size * 2 , node_hist_size * 2 );
257+ h_scans [nidx_in_set] = dh::ToSpan (this ->scan_buffer_ )
258+ .subspan (nidx_in_set * node_hist_size * 2 , node_hist_size * 2 );
263259 }
260+ dh::device_vector<common::Span<GradientPairInt64>> scans (h_scans);
264261
265262 // Launch histogram scan kernel
266263 dim3 grid{n_nodes, n_features, n_targets};
@@ -328,32 +325,40 @@ void MultiHistEvaluator::EvaluateSplits(Context const *ctx,
328325 s_parent_gains[nidx_in_set] = parent_gain;
329326
330327 bool l = true , r = true ;
328+ GradientPairPrecise lg_fst, rg_fst;
331329 for (bst_target_t t = 0 ; t < n_targets; ++t) {
332330 auto quantizer = d_roundings[t];
333331 auto sibling_sum = input.parent_sum [t] - node_sum[t];
334332
335333 l = l && (node_sum[t].GetQuantisedHess () == 0 );
336334 r = r && (sibling_sum.GetQuantisedHess () == 0 );
337335
336+ GradientPairPrecise lg, rg;
338337 if (best_split.dir == kRightDir ) {
339338 // forward pass, node_sum is the left sum
340- auto lg = quantizer.ToFloatingPoint (node_sum[t]);
339+ lg = quantizer.ToFloatingPoint (node_sum[t]);
341340 left_weight[t] = CalcWeight (shared_inputs.param , lg.GetGrad (), lg.GetHess ());
342- auto rg = quantizer.ToFloatingPoint (sibling_sum);
341+ rg = quantizer.ToFloatingPoint (sibling_sum);
343342 right_weight[t] = CalcWeight (shared_inputs.param , rg.GetGrad (), rg.GetHess ());
344343 } else {
345344 // backward pass, node_sum is the right sum
346- auto rg = quantizer.ToFloatingPoint (node_sum[t]);
345+ rg = quantizer.ToFloatingPoint (node_sum[t]);
347346 right_weight[t] = CalcWeight (shared_inputs.param , rg.GetGrad (), rg.GetHess ());
348- auto lg = quantizer.ToFloatingPoint (sibling_sum);
347+ lg = quantizer.ToFloatingPoint (sibling_sum);
349348 left_weight[t] = CalcWeight (shared_inputs.param , lg.GetGrad (), lg.GetHess ());
350349 }
350+
351+ if (t == 0 ) {
352+ lg_fst = lg;
353+ rg_fst = rg;
354+ }
351355 }
352356
353357 // Set up the output entry
354358 out_splits[nidx_in_set] = {input.nidx , input.depth , best_split,
355359 base_weight, left_weight, right_weight};
356360 out_splits[nidx_in_set].split .loss_chg -= parent_gain;
361+ out_splits[nidx_in_set].UpdateFirstHessian (lg_fst, rg_fst);
357362
358363 if (l || r) {
359364 out_splits[nidx_in_set] = {};
0 commit comments