Commit 64c8eef

Merge branch 'master' into fix-r-checks
2 parents: 88d18b9 + 36d2b42

8 files changed: +148 additions, -72 deletions


.github/workflows/jvm_tests.yml

Lines changed: 2 additions & 2 deletions
@@ -149,8 +149,8 @@ jobs:
       uses: actions/cache@v4
       with:
         path: ~/.m2
-        key: ${{ runner.os }}-m2-${{ hashFiles('./jvm-packages/pom.xml') }}
-        restore-keys: ${{ runner.os }}-m2-${{ hashFiles('./jvm-packages/pom.xml') }}
+        key: ${{ runner.os }}-m2-${{ hashFiles('/jvm-packages/pom.xml') }}
+        restore-keys: ${{ runner.os }}-m2-${{ hashFiles('/jvm-packages/pom.xml') }}
     - name: Test XGBoost4J (Core) on macos
       if: matrix.os == 'macos-15-intel'
       run: |

python-package/xgboost/testing/multi_target.py

Lines changed: 36 additions & 3 deletions
@@ -1,5 +1,6 @@
 """Tests for multi-target training."""
 
+# pylint: disable=unbalanced-tuple-unpacking
 from typing import Dict, Optional, Tuple
 
 import numpy as np
@@ -29,6 +30,7 @@ def run_multiclass(device: Device, learning_rate: Optional[float]) -> None:
         128, n_features=12, n_informative=10, n_classes=4, random_state=2025
     )
     clf = XGBClassifier(
+        debug_synchronize=True,
         multi_strategy="multi_output_tree",
         callbacks=[ResetStrategy()],
         n_estimators=10,
@@ -47,9 +49,9 @@ def run_multiclass(device: Device, learning_rate: Optional[float]) -> None:
 
 def run_multilabel(device: Device, learning_rate: Optional[float]) -> None:
     """Use vector leaf for multi-label classification models."""
-    # pylint: disable=unbalanced-tuple-unpacking
     X, y = make_multilabel_classification(128, random_state=2025)
     clf = XGBClassifier(
+        debug_synchronize=True,
         multi_strategy="multi_output_tree",
         callbacks=[ResetStrategy()],
         n_estimators=10,
@@ -103,7 +105,7 @@ def run_reduced_grad(device: Device) -> None:
     """Basic test for using reduced gradient for tree splits."""
     import cupy as cp
 
-    X, y = make_regression(  # pylint: disable=unbalanced-tuple-unpacking
+    X, y = make_regression(
         n_samples=1024, n_features=16, random_state=1994, n_targets=5
     )
     Xy = QuantileDMatrix(X, y)
@@ -114,6 +116,7 @@ def run_test(
     evals_result: Dict[str, Dict] = {}
     booster = train(
         {
+            "debug_synchronize": True,
             "device": device,
             "multi_strategy": "multi_output_tree",
             "learning_rate": 1,
@@ -184,7 +187,7 @@ def run_with_iter(device: Device) -> None:  # pylint: disable=too-many-locals
     Xs = []
     ys = []
     for i in range(n_batches):
-        X_i, y_i = make_regression(  # pylint: disable=unbalanced-tuple-unpacking
+        X_i, y_i = make_regression(
             n_samples=4096, n_features=8, random_state=(i + 1), n_targets=n_targets
         )
         Xs.append(asarray(X_i))
@@ -245,3 +248,33 @@ def run_with_iter(device: Device) -> None:  # pylint: disable=too-many-locals
         evals_result_0["Train"]["rmse"], evals_result_2["Train"]["rmse"]
     )
     assert_allclose(device, booster_0.inplace_predict(X), booster_2.inplace_predict(X))
+
+
+def run_eta(device: Device) -> None:
+    """Test for learning rate."""
+    X, y = make_regression(512, 16, random_state=2025, n_targets=3)
+
+    def run(obj: Optional[Objective]) -> None:
+        params = {
+            "device": device,
+            "multi_strategy": "multi_output_tree",
+            "learning_rate": 1.0,
+            "debug_synchronize": True,
+            "base_score": 0.0,
+        }
+        Xy = QuantileDMatrix(X, y)
+        booster_0 = train(params, Xy, num_boost_round=1, obj=obj)
+        params["learning_rate"] = 0.1
+        booster_1 = train(params, Xy, num_boost_round=1, obj=obj)
+        params["learning_rate"] = 2.0
+        booster_2 = train(params, Xy, num_boost_round=1, obj=obj)
+
+        predt_0 = booster_0.predict(Xy)
+        predt_1 = booster_1.predict(Xy)
+        predt_2 = booster_2.predict(Xy)
+
+        np.testing.assert_allclose(predt_0, predt_1 * 10, rtol=1e-6)
+        np.testing.assert_allclose(predt_0 * 2, predt_2, rtol=1e-6)
+
+    run(None)
+    run(LsObj0())
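
The new run_eta helper pins down the learning-rate semantics that the C++ changes below implement: with a single boosting round and base_score=0, a prediction is just the leaf value reached by the row, and that leaf value is shrunk by the learning rate, so predictions scale linearly with eta. A minimal numpy sketch of that invariant (the leaf values here are made up, not produced by XGBoost):

import numpy as np

# Hypothetical unshrunk per-target leaf value for one row; with one boosting
# round and base_score = 0, the prediction is eta times this value.
unshrunk_leaf = np.array([0.7, -1.2, 0.3])

predt_0 = unshrunk_leaf * 1.0  # learning_rate = 1.0
predt_1 = unshrunk_leaf * 0.1  # learning_rate = 0.1
predt_2 = unshrunk_leaf * 2.0  # learning_rate = 2.0

# The same relations run_eta asserts on real boosters.
np.testing.assert_allclose(predt_0, predt_1 * 10, rtol=1e-6)
np.testing.assert_allclose(predt_0 * 2, predt_2, rtol=1e-6)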

src/tree/gpu_hist/expand_entry.cuh

Lines changed: 1 addition & 1 deletion
@@ -132,7 +132,7 @@ struct MultiExpandEntry {
   bst_node_t depth{0};
   MultiSplitCandidate split;
 
-  common::Span<float const> base_weight;
+  common::Span<float> base_weight;
   common::Span<float const> left_weight;
   common::Span<float const> right_weight;
 

src/tree/gpu_hist/leaf_sum.cu

Lines changed: 1 addition & 1 deletion
@@ -88,7 +88,7 @@ void LeafWeight(Context const* ctx, GPUTrainingParam const& param,
   dh::LaunchN(grad_sum.Size(), ctx->CUDACtx()->Stream(), [=] XGBOOST_DEVICE(std::size_t i) mutable {
     auto [nidx_in_set, t] = linalg::UnravelIndex(i, grad_sum.Shape());
     auto g = roundings[t].ToFloatingPoint(grad_sum(nidx_in_set, t));
-    out_weights(nidx_in_set, t) = CalcWeight(param, g.GetGrad(), g.GetHess());
+    out_weights(nidx_in_set, t) = CalcWeight(param, g.GetGrad(), g.GetHess()) * param.learning_rate;
   });
 }
 }  // namespace xgboost::tree::cuda_impl
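
This hunk folds the learning rate into the vector-leaf weights themselves. A plain-numpy sketch of what the kernel computes per (leaf, target) pair, under the assumption that CalcWeight is the usual Newton step -G / (H + lambda) (the real CalcWeight also handles regularization details such as L1, omitted here):

import numpy as np

def leaf_weights(grad_sum, hess_sum, lam, learning_rate):
    # grad_sum and hess_sum have shape [n_leaves, n_targets]; the CUDA kernel
    # walks the same grid with one thread per flattened (leaf, target) index.
    return -grad_sum / (hess_sum + lam) * learning_rate

G = np.array([[1.0, -2.0], [0.5, 0.0]])
H = np.array([[3.0, 4.0], [2.0, 1.0]])
w = leaf_weights(G, H, lam=1.0, learning_rate=0.3)
# Doubling the learning rate doubles every leaf weight.
np.testing.assert_allclose(leaf_weights(G, H, 1.0, 0.6), 2 * w)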

src/tree/gpu_hist/multi_evaluate_splits.cu

Lines changed: 61 additions & 33 deletions
@@ -1,14 +1,15 @@
 /**
  * Copyright 2025, XGBoost contributors
  */
-#include <thrust/reduce.h>  // for reduce_by_key
+#include <thrust/reduce.h>  // for reduce_by_key, reduce
 
 #include <cub/block/block_scan.cuh>  // for BlockScan
 #include <cub/util_type.cuh>         // for KeyValuePair
 #include <cub/warp/warp_reduce.cuh>  // for WarpReduce
 #include <vector>                    // for vector
 
 #include "../../common/cuda_context.cuh"
+#include "../tree_view.h"             // for MultiTargetTreeView
 #include "../updater_gpu_common.cuh"  // for SumCallbackOp
 #include "multi_evaluate_splits.cuh"  // for MultiEvalauteSplitInputs, MultiEvaluateSplitSharedInputs
 #include "quantiser.cuh"              // for GradientQuantiser
@@ -221,7 +222,18 @@ __global__ __launch_bounds__(kBlockThreads) void EvaluateSplitsKernel(
   dh::device_vector<MultiEvaluateSplitInputs> inputs{input};
   dh::device_vector<MultiExpandEntry> outputs(1);
 
-  this->EvaluateSplits(ctx, dh::ToSpan(inputs), shared_inputs, dh::ToSpan(outputs));
+  auto d_outputs = dh::ToSpan(outputs);
+  this->EvaluateSplits(ctx, dh::ToSpan(inputs), shared_inputs, d_outputs);
+
+  auto n_targets = shared_inputs.Targets();
+  dh::LaunchN(n_targets, ctx->CUDACtx()->Stream(), [=] XGBOOST_DEVICE(std::size_t t) {
+    auto weight = d_outputs[0].base_weight;
+    if (weight.empty()) {
+      return;
+    }
+    weight[t] *= shared_inputs.param.learning_rate;
+  });
+
   return outputs[0];
 }
 
@@ -326,6 +338,7 @@ void MultiHistEvaluator::EvaluateSplits(Context const *ctx,
 
   bool l = true, r = true;
   GradientPairPrecise lg_fst, rg_fst;
+  auto eta = shared_inputs.param.learning_rate;
   for (bst_target_t t = 0; t < n_targets; ++t) {
     auto quantizer = d_roundings[t];
     auto sibling_sum = input.parent_sum[t] - node_sum[t];
@@ -337,15 +350,15 @@ void MultiHistEvaluator::EvaluateSplits(Context const *ctx,
     if (best_split.dir == kRightDir) {
       // forward pass, node_sum is the left sum
       lg = quantizer.ToFloatingPoint(node_sum[t]);
-      left_weight[t] = CalcWeight(shared_inputs.param, lg.GetGrad(), lg.GetHess());
+      left_weight[t] = CalcWeight(shared_inputs.param, lg.GetGrad(), lg.GetHess()) * eta;
       rg = quantizer.ToFloatingPoint(sibling_sum);
-      right_weight[t] = CalcWeight(shared_inputs.param, rg.GetGrad(), rg.GetHess());
+      right_weight[t] = CalcWeight(shared_inputs.param, rg.GetGrad(), rg.GetHess()) * eta;
     } else {
       // backward pass, node_sum is the right sum
       rg = quantizer.ToFloatingPoint(node_sum[t]);
-      right_weight[t] = CalcWeight(shared_inputs.param, rg.GetGrad(), rg.GetHess());
+      right_weight[t] = CalcWeight(shared_inputs.param, rg.GetGrad(), rg.GetHess()) * eta;
       lg = quantizer.ToFloatingPoint(sibling_sum);
-      left_weight[t] = CalcWeight(shared_inputs.param, lg.GetGrad(), lg.GetHess());
+      left_weight[t] = CalcWeight(shared_inputs.param, lg.GetGrad(), lg.GetHess()) * eta;
     }
 
     if (t == 0) {
@@ -367,35 +380,50 @@ void MultiHistEvaluator::EvaluateSplits(Context const *ctx,
 }
 
 void MultiHistEvaluator::ApplyTreeSplit(Context const *ctx, RegTree const *p_tree,
-                                        MultiExpandEntry const &candidate) {
-  auto left_child = p_tree->LeftChild(candidate.nidx);
-  auto right_child = p_tree->RightChild(candidate.nidx);
-  bst_node_t max_node = std::max(left_child, right_child);
-  auto n_targets = candidate.base_weight.size();
-
+                                        common::Span<MultiExpandEntry const> d_candidates,
+                                        bst_target_t n_targets) {
+  // Assign the node sums here, for the next evaluate split call.
+  auto mt_tree = MultiTargetTreeView{ctx->Device(), p_tree};
+  auto max_in_it = dh::MakeIndexTransformIter([=] __device__(std::size_t i) -> bst_node_t {
+    return std::max(mt_tree.LeftChild(d_candidates[i].nidx),
+                    mt_tree.RightChild(d_candidates[i].nidx));
+  });
+  auto max_node = thrust::reduce(
+      ctx->CUDACtx()->CTP(), max_in_it, max_in_it + d_candidates.size(), 0,
+      [=] XGBOOST_DEVICE(bst_node_t l, bst_node_t r) { return cuda::std::max(l, r); });
   this->AllocNodeSum(max_node, n_targets);
 
-  auto parent_sum = this->GetNodeSum(candidate.nidx, n_targets);
-  auto left_sum = this->GetNodeSum(left_child, n_targets);
-  auto right_sum = this->GetNodeSum(right_child, n_targets);
-
-  // Calculate node sums
-  // TODO(jiamingy): We need to batch the nodes
-  auto best_split = candidate.split;
-
-  auto node_sum = best_split.child_sum;
-  dh::LaunchN(n_targets, ctx->CUDACtx()->Stream(), [=] XGBOOST_DEVICE(std::size_t t) {
-    auto sibling_sum = parent_sum[t] - node_sum[t];
-    if (best_split.dir == kRightDir) {
-      // forward pass, node_sum is the left sum
-      left_sum[t] = node_sum[t];
-      right_sum[t] = sibling_sum;
-    } else {
-      // backward pass, node_sum is the right sum
-      right_sum[t] = node_sum[t];
-      left_sum[t] = sibling_sum;
-    }
-  });
+  auto node_sums = dh::ToSpan(this->node_sums_);
+
+  dh::LaunchN(n_targets * d_candidates.size(), ctx->CUDACtx()->Stream(),
+              [=] XGBOOST_DEVICE(std::size_t i) {
+                auto get_node_sum = [&](bst_node_t nidx) {
+                  return GetNodeSumImpl(node_sums, nidx, n_targets);
+                };
+                auto nidx_in_set = i / n_targets;
+                auto t = i % n_targets;
+
+                auto const &candidate = d_candidates[nidx_in_set];
+                auto const &best_split = candidate.split;
+
+                auto parent_sum = get_node_sum(candidate.nidx);
+                // The child sum is a pointer to the scan buffer in this evaluator. Copy
+                // the data into the node sum buffer before the next evaluation call.
+                auto node_sum = best_split.child_sum;
+                auto left_sum = get_node_sum(mt_tree.LeftChild(candidate.nidx));
+                auto right_sum = get_node_sum(mt_tree.RightChild(candidate.nidx));
+
+                auto sibling_sum = parent_sum[t] - node_sum[t];
+                if (best_split.dir == kRightDir) {
+                  // forward pass, node_sum is the left sum
+                  left_sum[t] = node_sum[t];
+                  right_sum[t] = sibling_sum;
+                } else {
+                  // backward pass, node_sum is the right sum
+                  right_sum[t] = node_sum[t];
+                  left_sum[t] = sibling_sum;
+                }
+              });
 }
 
 std::ostream &DebugPrintHistogram(std::ostream &os, common::Span<GradientPairInt64 const> node_hist,
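
Besides scaling the candidate weights by eta, ApplyTreeSplit now processes a whole batch of candidates in one kernel launch, mapping the flat thread index with nidx_in_set = i / n_targets and t = i % n_targets. A host-side numpy sketch of the sibling-sum bookkeeping (names are illustrative; the real code works on quantised GradientPairInt64 sums in a flat device buffer):

import numpy as np

K_RIGHT_DIR = "right"  # stands in for kRightDir (forward scan)

def apply_tree_split(node_sums, candidates):
    # node_sums: [n_nodes, n_targets]; each candidate carries the scanned child sum.
    for cand in candidates:
        parent = node_sums[cand["nidx"]]
        child = cand["child_sum"]
        sibling = parent - child
        if cand["dir"] == K_RIGHT_DIR:  # forward pass: child sum is the left sum
            node_sums[cand["left"]], node_sums[cand["right"]] = child, sibling
        else:                           # backward pass: child sum is the right sum
            node_sums[cand["right"]], node_sums[cand["left"]] = child, sibling

n_targets = 2
sums = np.zeros((3, n_targets))
sums[0] = [4.0, -1.0]  # root gradient sums
apply_tree_split(sums, [{"nidx": 0, "left": 1, "right": 2, "dir": K_RIGHT_DIR,
                         "child_sum": np.array([1.5, -0.25])}])
np.testing.assert_allclose(sums[1] + sums[2], sums[0])  # children partition the parent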

src/tree/gpu_hist/multi_evaluate_splits.cuh

Lines changed: 11 additions & 5 deletions
@@ -20,6 +20,13 @@ class MultiHistEvaluator {
   dh::device_vector<GradientPairInt64> node_sums_;
 
  public:
+  template <typename GradT>
+  static XGBOOST_DEVICE common::Span<GradT> GetNodeSumImpl(common::Span<GradT> node_sums,
+                                                           bst_node_t nidx,
+                                                           bst_target_t n_targets) {
+    auto offset = nidx * n_targets;
+    return node_sums.subspan(offset, n_targets);
+  }
   /**
    * @brief Run evaluation for the root node.
    */
@@ -41,17 +48,16 @@ class MultiHistEvaluator {
   }
   [[nodiscard]] common::Span<GradientPairInt64> GetNodeSum(bst_node_t nidx,
                                                            bst_target_t n_targets) {
-    auto offset = nidx * n_targets;
-    return dh::ToSpan(this->node_sums_).subspan(offset, n_targets);
+    return GetNodeSumImpl(dh::ToSpan(this->node_sums_), nidx, n_targets);
   }
   [[nodiscard]] common::Span<GradientPairInt64 const> GetNodeSum(bst_node_t nidx,
                                                                  bst_target_t n_targets) const {
-    auto offset = nidx * n_targets;
-    return dh::ToSpan(this->node_sums_).subspan(offset, n_targets);
+    return GetNodeSumImpl(dh::ToSpan(this->node_sums_), nidx, n_targets);
   }
 
   // Track the child gradient sum.
-  void ApplyTreeSplit(Context const *ctx, RegTree const *p_tree, MultiExpandEntry const &candidate);
+  void ApplyTreeSplit(Context const *ctx, RegTree const *p_tree,
+                      common::Span<MultiExpandEntry const> d_candidates, bst_target_t n_targets);
 };
 
 std::ostream &DebugPrintHistogram(std::ostream &os, common::Span<GradientPairInt64 const> node_hist,
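
GetNodeSumImpl makes the flat node-sum layout explicit so that the host accessors and the device kernel in ApplyTreeSplit can share it: one buffer of n_nodes * n_targets gradient sums, with node nidx owning the contiguous slice starting at nidx * n_targets. A small Python sketch of the same slicing:

import numpy as np

def get_node_sum(node_sums, nidx, n_targets):
    # Mirrors GetNodeSumImpl: a view over the node's n_targets entries,
    # analogous to common::Span::subspan(offset, n_targets).
    offset = nidx * n_targets
    return node_sums[offset:offset + n_targets]

n_nodes, n_targets = 5, 3
buf = np.arange(n_nodes * n_targets, dtype=float)
assert np.array_equal(get_node_sum(buf, 2, n_targets), buf[6:9])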

src/tree/updater_gpu_hist.cuh

Lines changed: 21 additions & 21 deletions
@@ -187,28 +187,31 @@ class MultiTargetHistMaker {
                                                  this->param_.max_bin,
                                                  param};
     auto entry = this->evaluator_.EvaluateSingleSplit(ctx_, input, shared_inputs);
-
-    // TODO(jiamingy): Support learning rate.
     p_tree->SetRoot(linalg::MakeVec(this->ctx_->Device(), entry.base_weight));
 
     return entry;
   }
 
-  void ApplySplit(MultiExpandEntry const& candidate, RegTree* p_tree) {
-    // TODO(jiamingy): Support learning rate.
+  void ApplySplit(std::vector<MultiExpandEntry> const& h_candidates, RegTree* p_tree) {
+    CHECK(!h_candidates.empty());
+    auto n_targets = h_candidates.front().base_weight.size();
+
     // TODO(jiamingy): Avoid device to host copies.
-    std::vector<float> h_base_weight(candidate.base_weight.size());
-    std::vector<float> h_left_weight(candidate.left_weight.size());
-    std::vector<float> h_right_weight(candidate.right_weight.size());
-    dh::CopyDeviceSpanToVector(&h_base_weight, candidate.base_weight);
-    dh::CopyDeviceSpanToVector(&h_left_weight, candidate.left_weight);
-    dh::CopyDeviceSpanToVector(&h_right_weight, candidate.right_weight);
-
-    p_tree->ExpandNode(candidate.nidx, candidate.split.findex, candidate.split.fvalue,
-                       candidate.split.dir == kLeftDir, linalg::MakeVec(h_base_weight),
-                       linalg::MakeVec(h_left_weight), linalg::MakeVec(h_right_weight));
-
-    this->evaluator_.ApplyTreeSplit(this->ctx_, p_tree, candidate);
+    for (auto const& candidate : h_candidates) {
+      std::vector<float> h_base_weight(candidate.base_weight.size());
+      std::vector<float> h_left_weight(candidate.left_weight.size());
+      std::vector<float> h_right_weight(candidate.right_weight.size());
+      dh::CopyDeviceSpanToVector(&h_base_weight, candidate.base_weight);
+      dh::CopyDeviceSpanToVector(&h_left_weight, candidate.left_weight);
+      dh::CopyDeviceSpanToVector(&h_right_weight, candidate.right_weight);
+
+      p_tree->ExpandNode(candidate.nidx, candidate.split.findex, candidate.split.fvalue,
+                         candidate.split.dir == kLeftDir, linalg::MakeVec(h_base_weight),
+                         linalg::MakeVec(h_left_weight), linalg::MakeVec(h_right_weight));
+    }
+
+    dh::device_vector<MultiExpandEntry> candidates{h_candidates};
+    this->evaluator_.ApplyTreeSplit(this->ctx_, p_tree, dh::ToSpan(candidates), n_targets);
   }
   /**
    * @brief Calculate the leaf weight based on the node sum for each leaf.
@@ -464,7 +467,7 @@ class MultiTargetHistMaker {
 
   void GrowTree(linalg::Matrix<GradientPair>* split_gpair, DMatrix* p_fmat, ObjInfo const*,
                 RegTree* p_tree) {
-    if (this->param_.learning_rate - 1.0 != 0.0) {
+    if (!this->hist_param_->debug_synchronize) {
       LOG(FATAL) << "GPU" << MTNotImplemented();
     }
     Driver<MultiExpandEntry> driver{param_, kMaxNodeBatchSize};
@@ -475,10 +478,7 @@ class MultiTargetHistMaker {
     // The set of leaves that can be expanded asynchronously
     auto expand_set = driver.Pop();
     while (!expand_set.empty()) {
-      for (auto& candidate : expand_set) {
-        this->ApplySplit(candidate, p_tree);
-      }
-
+      this->ApplySplit(expand_set, p_tree);
       // Get the candidates we are allowed to expand further
       // e.g. We do not bother further processing nodes whose children are beyond max depth
       std::vector<MultiExpandEntry> valid_candidates;
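
With the learning-rate TODOs removed, the GPU multi_output_tree path now accepts a non-default eta, while the new guard in GrowTree still requires debug_synchronize. A usage sketch based on the parameters used in the updated tests (the "cuda" device string and the synthetic dataset are assumptions, not part of this diff):

import xgboost
from sklearn.datasets import make_regression

X, y = make_regression(n_samples=512, n_features=16, random_state=2025, n_targets=3)
booster = xgboost.train(
    {
        "device": "cuda",
        "multi_strategy": "multi_output_tree",
        "learning_rate": 0.3,       # no longer restricted to 1.0 on GPU
        "debug_synchronize": True,  # still required by the GrowTree guard
    },
    xgboost.QuantileDMatrix(X, y),
    num_boost_round=8,
)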
