Skip to content

Commit b5201b1

Browse files
authored
[mt] Implement partitioning for GPU. (#11789)
1 parent f8f2705 commit b5201b1

File tree

10 files changed

+455
-141
lines changed

10 files changed

+455
-141
lines changed
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
"""Tests for multi-target training."""
2+
3+
from typing import Optional
4+
5+
from sklearn.datasets import make_classification, make_multilabel_classification
6+
7+
import xgboost.testing as tm
8+
9+
from ..sklearn import XGBClassifier
10+
from .updater import ResetStrategy
11+
from .utils import Device
12+
13+
14+
def run_multiclass(device: Device, learning_rate: Optional[float]) -> None:
    """Use vector leaf for multi-class models.

    Fits a classifier with ``multi_strategy="multi_output_tree"`` on a small
    synthetic 4-class dataset, then checks the selected objective, that the
    training ``mlogloss`` never increases, and the shape of the predicted
    probabilities.

    Parameters
    ----------
    device :
        Forwarded unchanged to the classifier.
    learning_rate :
        Forwarded unchanged to the classifier.

    """
    n_classes = 4
    # Seed the generator: without it every run trains on different data, which
    # makes a failure of the loss-trend assertion below impossible to reproduce.
    X, y = make_classification(
        128,
        n_features=12,
        n_informative=10,
        n_classes=n_classes,
        random_state=2025,
    )
    clf = XGBClassifier(
        multi_strategy="multi_output_tree",
        callbacks=[ResetStrategy()],
        n_estimators=10,
        device=device,
        learning_rate=learning_rate,
    )
    # The training data doubles as the evaluation set, so "validation_0" below
    # holds the training loss.
    clf.fit(X, y, eval_set=[(X, y)])
    assert clf.objective == "multi:softprob"
    assert tm.non_increasing(clf.evals_result()["validation_0"]["mlogloss"])

    # One probability column per class.
    proba = clf.predict_proba(X)
    assert proba.shape == (y.shape[0], n_classes)
30+
31+
32+
def run_multilabel(device: Device, learning_rate: Optional[float]) -> None:
    """Use vector leaf for multi-label classification models.

    Fits a classifier with ``multi_strategy="multi_output_tree"`` on a small
    synthetic multi-label dataset, then checks the selected objective, that the
    training ``logloss`` never increases, and that the predicted probabilities
    have the same shape as the label matrix.

    Parameters
    ----------
    device :
        Forwarded unchanged to the classifier.
    learning_rate :
        Forwarded unchanged to the classifier.

    """
    # Seed the generator: without it every run trains on different data, which
    # makes a failure of the loss-trend assertion below impossible to reproduce.
    X, y = make_multilabel_classification(128, random_state=2025)
    clf = XGBClassifier(
        multi_strategy="multi_output_tree",
        callbacks=[ResetStrategy()],
        n_estimators=10,
        device=device,
        learning_rate=learning_rate,
    )
    # The training data doubles as the evaluation set, so "validation_0" below
    # holds the training loss.
    clf.fit(X, y, eval_set=[(X, y)])
    assert clf.objective == "binary:logistic"
    assert tm.non_increasing(clf.evals_result()["validation_0"]["logloss"])

    # One probability per (sample, label) pair.
    proba = clf.predict_proba(X)
    assert proba.shape == y.shape

src/gbm/gbtree.cc

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,6 @@ void GBTree::DoBoost(DMatrix* p_fmat, linalg::Matrix<GradientPair>* in_gpair,
216216
CHECK(tparam_.tree_method == TreeMethod::kHist || tparam_.tree_method == TreeMethod::kAuto)
217217
<< "Only the hist tree method is supported for building multi-target trees with vector "
218218
"leaf.";
219-
CHECK(ctx_->IsCPU()) << "GPU is not yet supported for vector leaf.";
220219
}
221220

222221
TreesOneIter new_trees;

src/tree/gpu_hist/evaluate_splits.cuh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ class GPUHistEvaluator {
208208
struct MultiEvaluateSplitInputs {
209209
bst_node_t nidx;
210210
bst_node_t depth;
211-
common::Span<GradientPairInt64> parent_sum;
211+
common::Span<GradientPairInt64 const> parent_sum;
212212
common::Span<const GradientPairInt64> histogram;
213213
};
214214

src/tree/gpu_hist/multi_evaluate_splits.cu

Lines changed: 144 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,13 @@
1515

1616
namespace xgboost::tree::cuda_impl {
1717
namespace {
18-
__device__ bst_bin_t RevBinIdx(bst_bin_t gidx_end, bst_bin_t bin_idx) {
19-
return gidx_end - bin_idx - 1;
18+
/**
19+
* @brief Calculate the gradient index for the reverse pass
20+
*
21+
* @note All inputs are global across features.
22+
*/
23+
__device__ bst_bin_t RevBinIdx(bst_bin_t gidx_begin, bst_bin_t gidx_end, bst_bin_t bin_idx) {
24+
return gidx_begin + (gidx_end - bin_idx - 1);
2025
}
2126

2227
// Scan the histogram in 2 dim for all nodes
@@ -60,7 +65,7 @@ struct ScanHistogramAgent {
6065
__device__ void Backward(common::Span<GradientPairInt64 const> node_histogram,
6166
common::Span<GradientPairInt64> scan_result, bst_target_t t) {
6267
this->ScanFeature(node_histogram, scan_result, t,
63-
[&](bst_bin_t bin_idx) { return RevBinIdx(gidx_end, bin_idx); });
68+
[&](bst_bin_t bin_idx) { return RevBinIdx(gidx_begin, gidx_end, bin_idx); });
6469
}
6570
};
6671
} // namespace
@@ -112,7 +117,7 @@ struct EvaluateSplitAgent {
112117
template <std::int32_t d_step>
113118
__device__ void Numerical(MultiEvaluateSplitInputs const &node,
114119
MultiEvaluateSplitSharedInputs const &shared,
115-
common::Span<GradientPairInt64 const> f_scan,
120+
common::Span<GradientPairInt64 const> node_scan,
116121
MultiSplitCandidate *best_split) {
117122
static_assert(d_step == +1 || d_step == -1, "Invalid step.");
118123
// Calculate split gain for each bin
@@ -130,7 +135,7 @@ struct EvaluateSplitAgent {
130135
double gain = thread_active ? 0 : kNullGain;
131136

132137
if (thread_active) {
133-
auto scan_bin = f_scan.subspan(bin_idx * n_targets, n_targets);
138+
auto scan_bin = node_scan.subspan(bin_idx * n_targets, n_targets);
134139
for (bst_target_t t = 0; t < n_targets; ++t) {
135140
auto pg = roundings[t].ToFloatingPoint(node.parent_sum[t]);
136141
// left
@@ -155,7 +160,7 @@ struct EvaluateSplitAgent {
155160
// Update
156161
bst_bin_t split_gidx = bin_idx;
157162
if (d_step == -1) {
158-
split_gidx = RevBinIdx(gidx_end, bin_idx);
163+
split_gidx = RevBinIdx(gidx_begin, gidx_end, bin_idx);
159164
}
160165
float min_fvalue = shared.min_values[fidx];
161166
float fvalue;
@@ -168,7 +173,7 @@ struct EvaluateSplitAgent {
168173
fvalue = shared.feature_values[split_gidx - 1];
169174
}
170175
}
171-
auto scan_bin = f_scan.subspan(bin_idx * n_targets, n_targets);
176+
auto scan_bin = node_scan.subspan(bin_idx * n_targets, n_targets);
172177
// Missing values go to right in the forward pass, go to left in the backward pass.
173178
best_split->Update(gain, d_step == 1 ? kRightDir : kLeftDir, fvalue, fidx, scan_bin, false,
174179
shared.param, shared.roundings);
@@ -190,80 +195,120 @@ __global__ __launch_bounds__(kBlockThreads) void EvaluateSplitsKernel(
190195
using AgentT = EvaluateSplitAgent<kBlockThreads>;
191196
__shared__ typename AgentT::TempStorage temp_storage;
192197

193-
auto fidx = blockIdx.x;
194-
EvaluateSplitAgent<kBlockThreads> agent{&temp_storage, blockIdx.x};
198+
const auto nidx = blockIdx.x / shared.Features();
199+
bst_feature_t fidx = blockIdx.x % shared.Features();
200+
AgentT agent{&temp_storage, fidx};
195201

196202
auto n_targets = shared.Targets();
197203
// The number of histogram entries of a feature: its bin count times the number of targets
198204
auto f_hist_size =
199205
(shared.feature_segments[fidx + 1] - shared.feature_segments[fidx]) * n_targets;
200-
// TODO(jiamingy): Support more than a single node
206+
207+
auto candidate_idx = nidx * shared.Features() + fidx;
201208

202209
if (shared.one_pass != MultiEvaluateSplitSharedInputs::kBackward) {
203-
auto forward = bin_scans[0].subspan(0, nodes[0].histogram.size());
204-
auto f_scan = forward.subspan(shared.feature_segments[fidx] * n_targets, f_hist_size);
205-
agent.template Numerical<+1>(nodes[0], shared, f_scan, &out_candidates[fidx]);
210+
auto forward = bin_scans[nidx].subspan(0, nodes[nidx].histogram.size());
211+
agent.template Numerical<+1>(nodes[nidx], shared, forward, &out_candidates[candidate_idx]);
206212
}
207213
if (shared.one_pass != MultiEvaluateSplitSharedInputs::kForward) {
208-
auto backward = bin_scans[0].subspan(nodes[0].histogram.size(), nodes[0].histogram.size());
209-
auto f_scan = backward.subspan(shared.feature_segments[fidx] * n_targets, f_hist_size);
210-
agent.template Numerical<-1>(nodes[0], shared, f_scan, &out_candidates[fidx]);
214+
auto backward =
215+
bin_scans[nidx].subspan(nodes[nidx].histogram.size(), nodes[nidx].histogram.size());
216+
agent.template Numerical<-1>(nodes[nidx], shared, backward, &out_candidates[candidate_idx]);
211217
}
212218
}
213219

214220
[[nodiscard]] MultiExpandEntry MultiHistEvaluator::EvaluateSingleSplit(
215-
Context const *ctx, MultiEvaluateSplitInputs input,
216-
MultiEvaluateSplitSharedInputs shared_inputs) {
221+
Context const *ctx, MultiEvaluateSplitInputs const &input,
222+
MultiEvaluateSplitSharedInputs const &shared_inputs) {
223+
dh::device_vector<MultiEvaluateSplitInputs> inputs{input};
224+
dh::device_vector<MultiExpandEntry> outputs(1);
225+
226+
this->EvaluateSplits(ctx, dh::ToSpan(inputs), shared_inputs, dh::ToSpan(outputs));
227+
return outputs[0];
228+
}
229+
230+
void MultiHistEvaluator::EvaluateSplits(Context const *ctx,
231+
common::Span<MultiEvaluateSplitInputs const> d_inputs,
232+
MultiEvaluateSplitSharedInputs const &shared_inputs,
233+
common::Span<MultiExpandEntry> out_splits) {
217234
auto n_targets = shared_inputs.Targets();
218235
CHECK_GE(n_targets, 2);
219236
auto n_bins_per_feat_tar = shared_inputs.n_bins_per_feat_tar;
220237
CHECK_GE(n_bins_per_feat_tar, 1);
221238
auto n_features = shared_inputs.Features();
222239
CHECK_GE(n_features, 1);
223240

224-
dh::device_vector<MultiEvaluateSplitInputs> inputs{input};
241+
std::uint32_t n_nodes = d_inputs.size();
242+
CHECK_EQ(n_nodes, out_splits.size());
243+
244+
if (n_nodes == 0) {
245+
return;
246+
}
247+
248+
// Calculate total scan buffer size needed for all nodes
249+
auto node_hist_size = n_targets * n_features * n_bins_per_feat_tar;
250+
std::size_t total_hist_size = node_hist_size * n_nodes;
225251

226252
// Scan the histograms. One for forward and the other for backward.
227-
this->scan_buffer_.resize(input.histogram.size() * 2);
253+
this->scan_buffer_.resize(total_hist_size * 2);
228254
thrust::fill(ctx->CUDACtx()->CTP(), this->scan_buffer_.begin(), this->scan_buffer_.end(),
229255
GradientPairInt64{});
230-
dh::device_vector<common::Span<GradientPairInt64>> scans{dh::ToSpan(this->scan_buffer_)};
231-
std::uint32_t n_nodes = 1;
256+
257+
// Create spans for each node's scan results
258+
dh::device_vector<common::Span<GradientPairInt64>> scans(n_nodes);
259+
for (std::size_t nidx_in_set = 0; nidx_in_set < n_nodes; ++nidx_in_set) {
260+
scans[nidx_in_set] = dh::ToSpan(this->scan_buffer_)
261+
.subspan(nidx_in_set * node_hist_size * 2, node_hist_size * 2);
262+
}
263+
264+
// Launch histogram scan kernel
232265
dim3 grid{n_nodes, n_features, n_targets};
233266
std::uint32_t constexpr kBlockThreads = 32;
234267
dh::LaunchKernel{grid, kBlockThreads}( // NOLINT
235-
ScanHistogramKernel<kBlockThreads>, dh::ToSpan(inputs), shared_inputs, dh::ToSpan(scans));
268+
ScanHistogramKernel<kBlockThreads>, d_inputs, shared_inputs, dh::ToSpan(scans));
236269

237-
dh::device_vector<MultiSplitCandidate> d_splits(n_features);
238-
dh::LaunchKernel{n_features, kBlockThreads, 0, ctx->CUDACtx()->Stream()}( // NOLINT
239-
EvaluateSplitsKernel<kBlockThreads>, dh::ToSpan(inputs), shared_inputs, dh::ToSpan(scans),
270+
// Launch split evaluation kernel
271+
dh::device_vector<MultiSplitCandidate> d_splits(n_nodes * n_features);
272+
dh::LaunchKernel{n_nodes * n_features, kBlockThreads, 0, ctx->CUDACtx()->Stream()}( // NOLINT
273+
EvaluateSplitsKernel<kBlockThreads>, d_inputs, shared_inputs, dh::ToSpan(scans),
240274
dh::ToSpan(d_splits));
241275

242-
auto best_split = thrust::reduce(
243-
ctx->CUDACtx()->CTP(), d_splits.cbegin(), d_splits.cend(), MultiSplitCandidate{},
244-
[] XGBOOST_DEVICE(MultiSplitCandidate const &lhs, MultiSplitCandidate const &rhs)
245-
-> MultiSplitCandidate { return lhs.loss_chg > rhs.loss_chg ? lhs : rhs; });
276+
// Find best split for each node
277+
this->weights_.resize(n_nodes * n_targets * 3);
278+
auto d_weights = dh::ToSpan(this->weights_);
246279

247-
if (best_split.node_sum.empty()) {
248-
return {};
249-
}
280+
dh::CachingDeviceUVector<float> d_parent_gains(n_nodes);
281+
dh::CachingDeviceUVector<std::int32_t> sum_zeros(n_nodes * 2);
250282

251-
// Calculate leaf weights from gradient sum
252-
this->weights_.resize(n_targets * 3);
253-
auto d_weights = dh::ToSpan(this->weights_);
254-
auto base_weight = d_weights.subspan(0, n_targets);
255-
auto left_weight = d_weights.subspan(n_targets, n_targets);
256-
auto right_weight = d_weights.subspan(n_targets * 2, n_targets);
283+
auto s_parent_gains = dh::ToSpan(d_parent_gains);
284+
auto s_sum_zeros = dh::ToSpan(sum_zeros);
285+
auto s_d_splits = dh::ToSpan(d_splits);
257286

258-
dh::CachingDeviceUVector<float> d_parent_gain(1);
259-
dh::CachingDeviceUVector<std::int32_t> sum_zero(2);
287+
// Process results for each node
288+
dh::LaunchN(n_nodes, ctx->CUDACtx()->Stream(), [=] __device__(std::size_t nidx_in_set) {
289+
auto input = d_inputs[nidx_in_set];
260290

261-
auto s_pg = dh::ToSpan(d_parent_gain);
262-
auto s_sum_zero = dh::ToSpan(sum_zero);
291+
// Find best split among all features for this node
292+
MultiSplitCandidate best_split{};
293+
for (bst_feature_t f = 0; f < n_features; ++f) {
294+
auto candidate = s_d_splits[nidx_in_set * n_features + f];
295+
if (candidate.loss_chg > best_split.loss_chg) {
296+
best_split = candidate;
297+
}
298+
}
299+
300+
if (best_split.node_sum.empty()) {
301+
// Invalid split
302+
out_splits[nidx_in_set] = {};
303+
return;
304+
}
305+
306+
// Calculate weights for this node
307+
auto base_weight = d_weights.subspan(nidx_in_set * n_targets * 3, n_targets);
308+
auto left_weight = d_weights.subspan(nidx_in_set * n_targets * 3 + n_targets, n_targets);
309+
auto right_weight = d_weights.subspan(nidx_in_set * n_targets * 3 + n_targets * 2, n_targets);
263310

264-
dh::LaunchN(inputs.size(), ctx->CUDACtx()->Stream(), [=] __device__(std::size_t i) {
265311
auto d_roundings = shared_inputs.roundings;
266-
// the data inside the split candidates references the scan result.
267312
auto node_sum = best_split.node_sum;
268313

269314
float parent_gain = 0;
@@ -276,15 +321,15 @@ __global__ __launch_bounds__(kBlockThreads) void EvaluateSplitsKernel(
276321
base_weight[t] = CalcWeight(shared_inputs.param, g.GetGrad(), g.GetHess());
277322
parent_gain += -base_weight[t] * ThresholdL1(g.GetGrad(), shared_inputs.param.reg_alpha);
278323
}
279-
s_pg[0] = parent_gain;
324+
s_parent_gains[nidx_in_set] = parent_gain;
280325

281326
bool l = true, r = true;
282327
for (bst_target_t t = 0; t < n_targets; ++t) {
283328
auto quantizer = d_roundings[t];
284329
auto sibling_sum = input.parent_sum[t] - node_sum[t];
285330

286-
l = l && (node_sum[t].GetQuantisedHess() - .0 == .0);
287-
r = r && (sibling_sum.GetQuantisedHess() - .0 == .0);
331+
l = l && (node_sum[t].GetQuantisedHess() == 0);
332+
r = r && (sibling_sum.GetQuantisedHess() == 0);
288333

289334
if (best_split.dir == kRightDir) {
290335
// forward pass, node_sum is the left sum
@@ -299,43 +344,72 @@ __global__ __launch_bounds__(kBlockThreads) void EvaluateSplitsKernel(
299344
auto lg = quantizer.ToFloatingPoint(sibling_sum);
300345
left_weight[t] = CalcWeight(shared_inputs.param, lg.GetGrad(), lg.GetHess());
301346
}
347+
}
348+
349+
s_sum_zeros[nidx_in_set * 2] = l;
350+
s_sum_zeros[nidx_in_set * 2 + 1] = r;
351+
352+
// Set up the output entry
353+
out_splits[nidx_in_set] = {input.nidx, input.depth, best_split,
354+
base_weight, left_weight, right_weight};
355+
out_splits[nidx_in_set].split.loss_chg -= parent_gain;
302356

303-
s_sum_zero[0] = l;
304-
s_sum_zero[1] = r;
357+
if (l || r) {
358+
out_splits[nidx_in_set] = {};
305359
}
306360
});
307-
// Copy the result back to the host.
308-
float parent_gain = 0;
309-
dh::safe_cuda(cudaMemcpyAsync(&parent_gain, d_parent_gain.data(), sizeof(parent_gain),
310-
cudaMemcpyDefault, ctx->CUDACtx()->Stream()));
311-
best_split.loss_chg -= parent_gain;
312-
313-
std::vector<std::int32_t> h_sum_zero(s_sum_zero.size());
314-
dh::safe_cuda(cudaMemcpyAsync(h_sum_zero.data(), s_sum_zero.data(), s_sum_zero.size_bytes(),
315-
cudaMemcpyDefault, ctx->CUDACtx()->Stream()));
316-
if (h_sum_zero[0] || h_sum_zero[1]) {
317-
return {};
318-
}
361+
}
362+
363+
void MultiHistEvaluator::ApplyTreeSplit(Context const *ctx, RegTree const *p_tree,
364+
MultiExpandEntry const &candidate) {
365+
auto n_targets = p_tree->NumTargets();
366+
367+
auto left_child = p_tree->LeftChild(candidate.nidx);
368+
auto right_child = p_tree->RightChild(candidate.nidx);
369+
bst_node_t max_node = std::max(left_child, right_child);
370+
this->AllocNodeSum(max_node, n_targets);
371+
372+
auto parent_sum = this->GetNodeSum(candidate.nidx, n_targets);
319373

320-
MultiExpandEntry entry{input.nidx, input.depth, best_split,
321-
base_weight, left_weight, right_weight};
322-
return entry;
374+
auto left_sum = this->GetNodeSum(left_child, n_targets);
375+
auto right_sum = this->GetNodeSum(right_child, n_targets);
376+
377+
// Calculate node sums
378+
// TODO(jiamingy): We need to batch the targets and nodes
379+
auto best_split = candidate.split;
380+
auto node_sum = best_split.node_sum;
381+
dh::LaunchN(1, ctx->CUDACtx()->Stream(), [=] XGBOOST_DEVICE(std::size_t) {
382+
for (bst_target_t t = 0; t < n_targets; ++t) {
383+
auto sibling_sum = parent_sum[t] - node_sum[t];
384+
if (best_split.dir == kRightDir) {
385+
// forward pass, node_sum is the left sum
386+
left_sum[t] = node_sum[t];
387+
right_sum[t] = sibling_sum;
388+
} else {
389+
// backward pass, node_sum is the right sum
390+
right_sum[t] = node_sum[t];
391+
left_sum[t] = sibling_sum;
392+
}
393+
}
394+
});
323395
}
324396

325-
void DebugPrintHistogram(common::Span<GradientPairInt64 const> node_hist,
326-
common::Span<GradientQuantiser const> roundings, bst_target_t n_targets) {
397+
std::ostream &DebugPrintHistogram(std::ostream &os, common::Span<GradientPairInt64 const> node_hist,
398+
common::Span<GradientQuantiser const> roundings,
399+
bst_target_t n_targets) {
327400
std::vector<GradientQuantiser> h_roundings;
328401
thrust::copy(dh::tcbegin(roundings), dh::tcend(roundings), std::back_inserter(h_roundings));
329402
dh::CopyDeviceSpanToVector(&h_roundings, roundings);
330403

331404
std::vector<GradientPairInt64> h_node_hist(node_hist.size());
332405
dh::CopyDeviceSpanToVector(&h_node_hist, node_hist);
333406
for (bst_target_t t = 0; t < n_targets; ++t) {
334-
std::cout << "target:" << t << std::endl;
407+
os << "Target:" << t << std::endl;
335408
for (std::size_t i = t; i < h_node_hist.size() / n_targets; i += n_targets) {
336-
std::cout << h_roundings[t].ToFloatingPoint(h_node_hist[i]) << ", ";
409+
os << h_roundings[t].ToFloatingPoint(h_node_hist[i]) << ", ";
337410
}
338-
std::cout << std::endl;
411+
os << std::endl;
339412
}
413+
return os;
340414
}
341415
} // namespace xgboost::tree::cuda_impl

0 commit comments

Comments
 (0)