
Commit 41d6731

[mt] Cleanup partitioners. (#11820)

- Share the initialization code between sc tree and mt tree updaters.
- Share the temporary buffer between partitioners.
- Implement external memory support for reduced gradient.

1 parent 62a3a4b · commit 41d6731

File tree: 13 files changed (+194, -138 lines)

python-package/xgboost/testing/multi_target.py
Lines changed: 25 additions & 21 deletions

@@ -14,7 +14,7 @@
 
 from .._typing import ArrayLike
 from ..compat import import_cupy
-from ..core import Booster, DMatrix, ExtMemQuantileDMatrix, QuantileDMatrix
+from ..core import Booster, DMatrix, ExtMemQuantileDMatrix, QuantileDMatrix, build_info
 from ..objective import Objective, TreeObjective
 from ..sklearn import XGBClassifier
 from ..training import train
@@ -172,6 +172,15 @@ def run_with_iter(device: Device) -> None:  # pylint: disable=too-many-locals
     n_rounds = 8
     n_targets = 3
     intercept = [0.5] * n_targets
+
+    params = {
+        "device": device,
+        "multi_strategy": "multi_output_tree",
+        "learning_rate": 1.0,
+        "base_score": intercept,
+        "debug_synchronize": True,
+    }
+
     Xs = []
     ys = []
     for i in range(n_batches):
@@ -185,12 +194,7 @@ def run_with_iter(device: Device) -> None:  # pylint: disable=too-many-locals
 
     evals_result_0: Dict[str, Dict] = {}
     booster_0 = train(
-        {
-            "device": device,
-            "multi_strategy": "multi_output_tree",
-            "learning_rate": 1.0,
-            "base_score": intercept,
-        },
+        params,
         Xy,
         num_boost_round=n_rounds,
         evals=[(Xy, "Train")],
@@ -201,12 +205,7 @@ def run_with_iter(device: Device) -> None:  # pylint: disable=too-many-locals
     Xy = QuantileDMatrix(it)
     evals_result_1: Dict[str, Dict] = {}
     booster_1 = train(
-        {
-            "device": device,
-            "multi_strategy": "multi_output_tree",
-            "learning_rate": 1.0,
-            "base_score": intercept,
-        },
+        params,
         Xy,
         num_boost_round=n_rounds,
         evals=[(Xy, "Train")],
@@ -219,18 +218,23 @@ def run_with_iter(device: Device) -> None:  # pylint: disable=too-many-locals
     X, _, _ = it.as_arrays()
     assert_allclose(device, booster_0.inplace_predict(X), booster_1.inplace_predict(X))
 
-    it = IteratorForTest(Xs, ys, None, cache="cache", on_host=True)
+    v = build_info()["THRUST_VERSION"]
+    if v[0] < 3:
+        pytest.xfail("CCCL version too old.")
+
+    it = IteratorForTest(
+        Xs,
+        ys,
+        None,
+        cache="cache",
+        on_host=True,
+        min_cache_page_bytes=X.shape[0] // n_batches * X.shape[1],
+    )
     Xy = ExtMemQuantileDMatrix(it, cache_host_ratio=1.0)
 
     evals_result_2: Dict[str, Dict] = {}
     booster_2 = train(
-        {
-            "device": device,
-            "multi_strategy": "multi_output_tree",
-            "learning_rate": 1.0,
-            "base_score": intercept,
-            "debug_synchronize": True,
-        },
+        params,
         Xy,
         evals=[(Xy, "Train")],
         obj=LsObj0(),

src/common/device_vector.cuh
Lines changed: 1 addition & 1 deletion

@@ -411,7 +411,7 @@ template <typename T>
 using caching_device_vector = thrust::device_vector<T, XGBCachingDeviceAllocator<T>>;  // NOLINT
 
 /**
- * @brief Container class that doesn't initialize the data when RMM is used.
+ * @brief Container class that doesn't initialize the data.
  */
 template <typename T, bool is_caching>
 class DeviceUVectorImpl {

src/tree/gpu_hist/expand_entry.cu
Lines changed: 3 additions & 3 deletions

@@ -16,8 +16,8 @@ std::ostream& operator<<(std::ostream& os, MultiExpandEntry const& e) {
      << "depth: " << e.depth << "\n"
      << "loss: " << e.split.loss_chg << "\n";
 
-  std::vector<GradientPairInt64> h_node_sum(e.split.node_sum.size());
-  dh::CopyDeviceSpanToVector(&h_node_sum, e.split.node_sum);
+  std::vector<GradientPairInt64> h_node_sum(e.split.child_sum.size());
+  dh::CopyDeviceSpanToVector(&h_node_sum, e.split.child_sum);
 
   auto print_span = [&](auto const& span) {
     using T = typename common::GetValueT<decltype(span)>::value_type;
@@ -38,7 +38,7 @@ std::ostream& operator<<(std::ostream& os, MultiExpandEntry const& e) {
   } else {
     os << "right_sum: ";
   }
-  print_span(e.split.node_sum);
+  print_span(e.split.child_sum);
 
   os << "base_weight: ";
   print_span(e.base_weight);

src/tree/gpu_hist/leaf_sum.cu
Lines changed: 27 additions & 10 deletions

@@ -1,11 +1,15 @@
 /**
  * Copyright 2025, XGBoost contributors
  */
-#include <cstddef>  // for size_t
-#include <vector>   // for vector
+#include <thrust/scan.h>     // for inclusive_scan
+#include <thrust/version.h>  // for THRUST_MAJOR_VERSION
 
-#include "../../common/linalg_op.cuh"  // for tbegin
-#include "../updater_gpu_common.cuh"   // for GPUTrainingParam
+#include <cstddef>  // for size_t
+#include <cstdint>  // for int32_t
+#include <cub/device/device_segmented_reduce.cuh>  // for DeviceSegmentedReduce
+#include <vector>  // for vector
+
+#include "../updater_gpu_common.cuh"  // for GPUTrainingParam
 #include "leaf_sum.cuh"
 #include "quantiser.cuh"  // for GradientQuantiser
 #include "row_partitioner.cuh"  // for RowIndexT, LeafInfo
@@ -14,6 +18,12 @@
 #include "xgboost/linalg.h"  // for MatrixView
 #include "xgboost/span.h"  // for Span
 
+#if THRUST_MAJOR_VERSION >= 3
+#include <thrust/iterator/tabulate_output_iterator.h>  // for make_tabulate_output_iterator
+#else
+#include "../../common/linalg_op.cuh"  // for tbegin
+#endif
+
 namespace xgboost::tree::cuda_impl {
 void LeafGradSum(Context const* ctx, std::vector<LeafInfo> const& h_leaves,
                  common::Span<GradientQuantiser const> roundings,
@@ -50,14 +60,22 @@ void LeafGradSum(Context const* ctx, std::vector<LeafInfo> const& h_leaves,
       auto g = grad(sorted_ridx[j], t);
       return roundings[t].ToFixedPoint(g);
     });
+    // Use an output iterator to implement running sum.
+#if THRUST_MAJOR_VERSION >= 3
+    auto out_it = thrust::make_tabulate_output_iterator(
+        [=] XGBOOST_DEVICE(std::int32_t idx, GradientPairInt64 v) mutable { out_t(idx) += v; });
+#else
+    auto out_it = linalg::tbegin(out_t);
+#endif
+
     std::size_t n_bytes = 0;
-    dh::safe_cuda(cub::DeviceSegmentedReduce::Sum(nullptr, n_bytes, it, linalg::tbegin(out_t),
-                                                  h_leaves.size(), indptr.data(), indptr.data() + 1,
+    dh::safe_cuda(cub::DeviceSegmentedReduce::Sum(nullptr, n_bytes, it, out_it, h_leaves.size(),
+                                                  indptr.data(), indptr.data() + 1,
                                                   ctx->CUDACtx()->Stream()));
     dh::TemporaryArray<char> alloc(n_bytes);
-    dh::safe_cuda(cub::DeviceSegmentedReduce::Sum(
-        alloc.data().get(), n_bytes, it, linalg::tbegin(out_t), h_leaves.size(), indptr.data(),
-        indptr.data() + 1, ctx->CUDACtx()->Stream()));
+    dh::safe_cuda(cub::DeviceSegmentedReduce::Sum(alloc.data().get(), n_bytes, it, out_it,
+                                                  h_leaves.size(), indptr.data(), indptr.data() + 1,
+                                                  ctx->CUDACtx()->Stream()));
   }
 }
 
@@ -66,7 +84,6 @@ void LeafWeight(Context const* ctx, GPUTrainingParam const& param,
                 linalg::MatrixView<GradientPairInt64 const> grad_sum,
                 linalg::MatrixView<float> out_weights) {
   CHECK(grad_sum.Contiguous());
-  auto s_grad_sum = grad_sum.Values();
   dh::LaunchN(grad_sum.Size(), ctx->CUDACtx()->Stream(), [=] XGBOOST_DEVICE(std::size_t i) mutable {
     auto [nidx_in_set, t] = linalg::UnravelIndex(i, grad_sum.Shape());
     auto g = roundings[t].ToFloatingPoint(grad_sum(nidx_in_set, t));
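The tabulate output iterator in the hunk above makes cub::DeviceSegmentedReduce::Sum add each per-leaf sum to the value already stored in out_t instead of overwriting it, so the leaf gradient sums accumulate correctly when gradients arrive batch by batch. Below is a minimal host-side sketch of that output-iterator idea in plain standard C++; it is not XGBoost or Thrust code, and the TabulateOutputIterator class and the accumulate callback are invented for illustration.

// Standalone sketch (not XGBoost/Thrust code): every value written through the
// iterator is handed to a callback together with its position.  A callback that
// does `out[i] += v` turns an algorithm that "writes" results into one that
// accumulates them, which is the running-sum trick used in LeafGradSum above.
#include <algorithm>  // for copy
#include <cstddef>    // for ptrdiff_t, size_t
#include <iostream>   // for cout
#include <iterator>   // for output_iterator_tag
#include <vector>     // for vector

template <typename Fn>
class TabulateOutputIterator {
 public:
  using iterator_category = std::output_iterator_tag;
  using value_type = void;
  using difference_type = std::ptrdiff_t;
  using pointer = void;
  using reference = void;

  explicit TabulateOutputIterator(Fn fn) : fn_{fn} {}

  // Assignment through the dereferenced iterator invokes the callback with (index, value).
  class Proxy {
   public:
    Proxy(Fn fn, std::size_t idx) : fn_{fn}, idx_{idx} {}
    template <typename T>
    Proxy& operator=(T const& v) {
      fn_(idx_, v);
      return *this;
    }

   private:
    Fn fn_;
    std::size_t idx_;
  };

  Proxy operator*() { return Proxy{fn_, idx_}; }
  TabulateOutputIterator& operator++() {
    ++idx_;
    return *this;
  }
  TabulateOutputIterator operator++(int) {
    auto tmp = *this;
    ++idx_;
    return tmp;
  }

 private:
  Fn fn_;
  std::size_t idx_{0};
};

int main() {
  // Per-leaf sums produced by one batch.
  std::vector<int> batch_sums{1, 2, 3};
  // Totals that already contain the contribution of earlier batches.
  std::vector<int> totals{10, 20, 30};

  auto accumulate = [&](std::size_t i, int v) { totals[i] += v; };
  TabulateOutputIterator<decltype(accumulate)> out_it{accumulate};

  // Writing the new batch through the iterator adds to the existing totals.
  std::copy(batch_sums.begin(), batch_sums.end(), out_it);

  for (int v : totals) {
    std::cout << v << " ";  // prints: 11 22 33
  }
  std::cout << "\n";
}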

src/tree/gpu_hist/multi_evaluate_splits.cu
Lines changed: 3 additions & 9 deletions

@@ -280,10 +280,7 @@ void MultiHistEvaluator::EvaluateSplits(Context const *ctx,
   auto d_weights = dh::ToSpan(this->weights_);
 
   dh::CachingDeviceUVector<float> d_parent_gains(n_nodes);
-  dh::CachingDeviceUVector<std::int32_t> sum_zeros(n_nodes * 2);
-
   auto s_parent_gains = dh::ToSpan(d_parent_gains);
-  auto s_sum_zeros = dh::ToSpan(sum_zeros);
   auto s_d_splits = dh::ToSpan(d_splits);
 
   // Process results for each node
@@ -304,7 +301,7 @@
   dh::LaunchN(n_nodes, ctx->CUDACtx()->Stream(), [=] __device__(std::size_t nidx_in_set) {
     auto input = d_inputs[nidx_in_set];
     MultiSplitCandidate best_split = d_best_splits[nidx_in_set];
-    if (best_split.node_sum.empty()) {
+    if (best_split.child_sum.empty()) {
       // Invalid split
       out_splits[nidx_in_set] = {};
       return;
@@ -316,7 +313,7 @@
     auto right_weight = d_weights.subspan(nidx_in_set * n_targets * 3 + n_targets * 2, n_targets);
 
     auto d_roundings = shared_inputs.roundings;
-    auto node_sum = best_split.node_sum;
+    auto node_sum = best_split.child_sum;
 
     float parent_gain = 0;
     for (bst_target_t t = 0; t < n_targets; ++t) {
@@ -353,9 +350,6 @@
       }
     }
 
-    s_sum_zeros[nidx_in_set * 2] = l;
-    s_sum_zeros[nidx_in_set * 2 + 1] = r;
-
     // Set up the output entry
     out_splits[nidx_in_set] = {input.nidx, input.depth, best_split,
                                base_weight, left_weight, right_weight};
@@ -384,7 +378,7 @@ void MultiHistEvaluator::ApplyTreeSplit(Context const *ctx, RegTree const *p_tre
   // TODO(jiamingy): We need to batch the nodes
   auto best_split = candidate.split;
 
-  auto node_sum = best_split.node_sum;
+  auto node_sum = best_split.child_sum;
   dh::LaunchN(n_targets, ctx->CUDACtx()->Stream(), [=] XGBOOST_DEVICE(std::size_t t) {
     auto sibling_sum = parent_sum[t] - node_sum[t];
     if (best_split.dir == kRightDir) {

src/tree/gpu_hist/row_partitioner.cu
Lines changed: 1 addition & 2 deletions

@@ -1,5 +1,5 @@
 /**
- * Copyright 2017-2024, XGBoost contributors
+ * Copyright 2017-2025, XGBoost contributors
  */
 #include <thrust/sequence.h>  // for sequence
 
@@ -13,7 +13,6 @@ namespace xgboost::tree {
 void RowPartitioner::Reset(Context const* ctx, bst_idx_t n_samples, bst_idx_t base_rowid) {
   ridx_segments_.clear();
   ridx_.resize(n_samples);
-  ridx_tmp_.resize(n_samples);
   tmp_.clear();
   n_nodes_ = 1;  // Root
 

src/tree/gpu_hist/row_partitioner.cuh
Lines changed: 75 additions & 17 deletions

@@ -174,21 +174,24 @@ void SortPositionBatch(Context const* ctx, common::Span<const PerNodeData<OpData
     auto ret =
         cub::DispatchScan<decltype(input_iterator), decltype(discard_write_iterator), IndexFlagOp,
                           cub::NullType, std::uint64_t>::Dispatch(nullptr, n_bytes, input_iterator,
-                                                                  discard_write_iterator,
-                                                                  IndexFlagOp{}, cub::NullType{},
-                                                                  static_cast<std::uint64_t>(total_rows),
-                                                                  ctx->CUDACtx()->Stream());
+                                                                  discard_write_iterator,
+                                                                  IndexFlagOp{}, cub::NullType{},
+                                                                  static_cast<std::uint64_t>(
+                                                                      total_rows),
+                                                                  ctx->CUDACtx()->Stream());
     dh::safe_cuda(ret);
     tmp->resize(n_bytes);
   }
   n_bytes = tmp->size();
   auto ret =
       cub::DispatchScan<decltype(input_iterator), decltype(discard_write_iterator), IndexFlagOp,
-                        cub::NullType, std::uint64_t>::Dispatch(tmp->data(), n_bytes, input_iterator,
-                                                                discard_write_iterator,
-                                                                IndexFlagOp{}, cub::NullType{},
-                                                                static_cast<std::uint64_t>(total_rows),
-                                                                ctx->CUDACtx()->Stream());
+                        cub::NullType, std::uint64_t>::Dispatch(tmp->data(), n_bytes,
+                                                                input_iterator,
+                                                                discard_write_iterator,
+                                                                IndexFlagOp{}, cub::NullType{},
+                                                                static_cast<std::uint64_t>(
+                                                                    total_rows),
+                                                                ctx->CUDACtx()->Stream());
   dh::safe_cuda(ret);
 
   constexpr int kBlockSize = 256;
@@ -272,8 +275,6 @@ class RowPartitioner {
   * rows idx     | 3, 5, 1 | 13, 31 |
   */
  dh::DeviceUVector<RowIndexT> ridx_;
- // Staging area for sorting ridx
- dh::DeviceUVector<RowIndexT> ridx_tmp_;
  dh::DeviceUVector<int8_t> tmp_;
  dh::PinnedMemory pinned_;
  dh::PinnedMemory pinned2_;
@@ -343,7 +344,8 @@ class RowPartitioner {
  void UpdatePositionBatch(Context const* ctx, std::vector<bst_node_t> const& nidx,
                           std::vector<bst_node_t> const& left_nidx,
                           std::vector<bst_node_t> const& right_nidx,
-                          std::vector<OpDataT> const& op_data, UpdatePositionOpT op) {
+                          std::vector<OpDataT> const& op_data, common::Span<RowIndexT> ridx_tmp,
+                          UpdatePositionOpT op) {
    if (nidx.empty()) {
      return;
    }
@@ -366,20 +368,21 @@
    auto h_counts = pinned_.GetSpan<RowIndexT>(nidx.size());
    // Must initialize with 0 as 0 count is not written in the kernel.
    dh::TemporaryArray<RowIndexT> d_counts(nidx.size(), 0);
+   CHECK_EQ(ridx_tmp.size(), this->Size());
 
    // Process a sub-batch
-   auto sub_batch_impl = [ctx, op, this](common::Span<bst_node_t const> nidx,
-                                         common::Span<PerNodeData<OpDataT>> d_batch_info,
-                                         common::Span<RowIndexT> d_counts) {
+   auto sub_batch_impl = [&](common::Span<bst_node_t const> nidx,
+                             common::Span<PerNodeData<OpDataT>> d_batch_info,
+                             common::Span<RowIndexT> d_counts) {
      std::size_t total_rows = 0;
      for (bst_node_t i : nidx) {
        total_rows += this->ridx_segments_[i].segment.Size();
      }
 
      // Partition the rows according to the operator
      SortPositionBatch<UpdatePositionOpT, OpDataT>(ctx, d_batch_info, dh::ToSpan(this->ridx_),
-                                                   dh::ToSpan(this->ridx_tmp_), d_counts,
-                                                   total_rows, op, &this->tmp_);
+                                                   ridx_tmp, d_counts, total_rows, op,
+                                                   &this->tmp_);
    };
 
    // Divide inputs into sub-batches.
@@ -441,4 +444,59 @@ class RowPartitioner {
                                 base_ridx, d_ridx, d_out_position, op);
  }
};
+
+// Partitioner for all batches, used for external memory training.
+class RowPartitionerBatches {
+ private:
+  // Temporary buffer for sorting the samples.
+  dh::DeviceUVector<cuda_impl::RowIndexT> ridx_tmp_;
+  // Partitioners for each batch.
+  std::vector<std::unique_ptr<RowPartitioner>> partitioners_;
+
+ public:
+  void Reset(Context const* ctx, std::vector<bst_idx_t> const& batch_ptr) {
+    CHECK_GE(batch_ptr.size(), 2);
+    std::size_t n_batches = batch_ptr.size() - 1;
+    if (partitioners_.size() != n_batches) {
+      partitioners_.clear();
+    }
+
+    bst_idx_t n_max_samples = 0;
+    for (std::size_t k = 0; k < n_batches; ++k) {
+      if (partitioners_.size() != n_batches) {
+        // First run.
+        partitioners_.emplace_back(std::make_unique<RowPartitioner>());
+      }
+      auto base_ridx = batch_ptr[k];
+      auto n_samples = batch_ptr.at(k + 1) - base_ridx;
+      partitioners_[k]->Reset(ctx, n_samples, base_ridx);
+      CHECK_LE(n_samples, std::numeric_limits<cuda_impl::RowIndexT>::max());
+      n_max_samples = std::max(n_samples, n_max_samples);
+    }
+    this->ridx_tmp_.resize(n_max_samples);
+  }
+
+  // Accessors
+  [[nodiscard]] decltype(auto) operator[](std::size_t i) { return partitioners_[i]; }
+  decltype(auto) At(std::size_t i) { return partitioners_.at(i); }
+  [[nodiscard]] std::size_t Size() const { return this->partitioners_.size(); }
+  decltype(auto) cbegin() const { return this->partitioners_.cbegin(); }  // NOLINT
+  decltype(auto) cend() const { return this->partitioners_.cend(); }      // NOLINT
+  decltype(auto) begin() const { return this->partitioners_.cbegin(); }   // NOLINT
+  decltype(auto) end() const { return this->partitioners_.cend(); }       // NOLINT
+
+  [[nodiscard]] decltype(auto) Front() { return this->partitioners_.front(); }
+  [[nodiscard]] bool Empty() const { return this->partitioners_.empty(); }
+
+  template <typename UpdatePositionOpT, typename OpDataT>
+  void UpdatePositionBatch(Context const* ctx, std::int32_t batch_idx,
+                           std::vector<bst_node_t> const& nidx,
+                           std::vector<bst_node_t> const& left_nidx,
+                           std::vector<bst_node_t> const& right_nidx,
+                           std::vector<OpDataT> const& op_data, UpdatePositionOpT op) {
+    auto& part = this->At(batch_idx);
+    auto ridx_tmp = dh::ToSpan(this->ridx_tmp_).subspan(0, part->Size());
+    part->UpdatePositionBatch(ctx, nidx, left_nidx, right_nidx, op_data, ridx_tmp, op);
+  }
+};
 };  // namespace xgboost::tree
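The new RowPartitionerBatches above keeps one staging buffer (ridx_tmp_) sized for the largest batch and lends a prefix of it to whichever per-batch RowPartitioner is currently being updated, so the partitioners no longer each own their own ridx_tmp_ staging area. The following standalone sketch shows the same ownership pattern in plain C++ with invented names (BatchPartitioner, PartitionerBatches) and a simple CPU stable partition standing in for the device-side SortPositionBatch; it assumes C++20 for std::span and is only an illustration, not the XGBoost implementation.

// Requires C++20 (std::span). Simplified sketch; the class names are hypothetical
// and not the XGBoost classes.  Each batch keeps its own partitioner, but all of
// them borrow one scratch buffer sized for the largest batch instead of each
// allocating its own staging area.
#include <algorithm>  // for copy, max
#include <cstddef>    // for size_t
#include <cstdint>    // for uint32_t
#include <iostream>   // for cout
#include <memory>     // for unique_ptr
#include <numeric>    // for iota
#include <span>       // for span
#include <vector>     // for vector

using RowIndex = std::uint32_t;

// One partitioner per batch; it owns the row index buffer for its batch but
// borrows scratch space for the partition step from the caller.
class BatchPartitioner {
 public:
  void Reset(std::size_t n_samples) {
    ridx_.resize(n_samples);
    std::iota(ridx_.begin(), ridx_.end(), RowIndex{0});
  }
  std::size_t Size() const { return ridx_.size(); }

  // Stable-partition the rows with `goes_left`, using caller-provided scratch.
  template <typename Pred>
  void UpdatePosition(std::span<RowIndex> scratch, Pred goes_left) {
    auto* out = scratch.data();
    std::size_t k = 0;
    for (RowIndex r : ridx_) {  // first pass: rows that go to the left child
      if (goes_left(r)) out[k++] = r;
    }
    for (RowIndex r : ridx_) {  // second pass: rows that go to the right child
      if (!goes_left(r)) out[k++] = r;
    }
    // scratch was sized to exactly this batch by the caller.
    std::copy(scratch.begin(), scratch.end(), ridx_.begin());
  }

  std::vector<RowIndex> const& Rows() const { return ridx_; }

 private:
  std::vector<RowIndex> ridx_;
};

// Holds a partitioner per external-memory batch plus one shared scratch buffer
// sized for the largest batch, mirroring the ownership in RowPartitionerBatches.
class PartitionerBatches {
 public:
  void Reset(std::vector<std::size_t> const& batch_sizes) {
    partitioners_.clear();
    std::size_t max_n = 0;
    for (std::size_t n : batch_sizes) {
      partitioners_.push_back(std::make_unique<BatchPartitioner>());
      partitioners_.back()->Reset(n);
      max_n = std::max(max_n, n);
    }
    scratch_.resize(max_n);  // one allocation serves every batch
  }

  template <typename Pred>
  void UpdatePosition(std::size_t batch_idx, Pred goes_left) {
    auto& part = *partitioners_.at(batch_idx);
    // Lend only the prefix this batch actually needs.
    std::span<RowIndex> scratch{scratch_.data(), part.Size()};
    part.UpdatePosition(scratch, goes_left);
  }

  BatchPartitioner const& At(std::size_t i) const { return *partitioners_.at(i); }

 private:
  std::vector<RowIndex> scratch_;
  std::vector<std::unique_ptr<BatchPartitioner>> partitioners_;
};

int main() {
  PartitionerBatches batches;
  batches.Reset({5, 3});  // two batches of different sizes share one scratch buffer
  batches.UpdatePosition(0, [](RowIndex r) { return r % 2 == 0; });
  for (RowIndex r : batches.At(0).Rows()) std::cout << r << " ";  // prints: 0 2 4 1 3
  std::cout << "\n";
}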
