Skip to content

Commit 425d136

Browse files
authored
Replace the device model. (#11752)
- Support leaf, QDM, inplace prediction for multi-target. - Replace device model with tree-internal storage. - Cleanup predict kernel.
1 parent 70b9b10 commit 425d136

File tree

14 files changed

+353
-503
lines changed

14 files changed

+353
-503
lines changed

include/xgboost/tree_model.h

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -262,12 +262,18 @@ class RegTree : public Model {
262262
Node& operator[](bst_node_t nidx) { return nodes_.HostVector()[nidx]; }
263263

264264
public:
265-
/*! \brief get const reference to nodes */
266-
[[nodiscard]] const std::vector<Node>& GetNodes() const { return nodes_.ConstHostVector(); }
265+
/** @brief Get const reference to nodes */
266+
[[nodiscard]] common::Span<Node const> GetNodes(DeviceOrd device) const {
267+
CHECK(!this->IsMultiTarget());
268+
return device.IsCPU() ? nodes_.ConstHostSpan()
269+
: (nodes_.SetDevice(device), nodes_.ConstDeviceSpan());
270+
}
267271

268-
/*! \brief get const reference to stats */
269-
[[nodiscard]] const std::vector<RTreeNodeStat>& GetStats() const {
270-
return stats_.ConstHostVector();
272+
/** @brief Get const reference to stats */
273+
[[nodiscard]] common::Span<RTreeNodeStat const> GetStats(DeviceOrd device) const {
274+
CHECK(!this->IsMultiTarget());
275+
return device.IsCPU() ? stats_.ConstHostSpan()
276+
: (stats_.SetDevice(device), stats_.ConstDeviceSpan());
271277
}
272278

273279
/*! \brief get node statistics given nid */

plugin/sycl/predictor/predictor.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,13 +75,13 @@ class DeviceModel {
7575
if (model.trees[tree_idx]->HasCategoricalSplit()) {
7676
LOG(FATAL) << "Categorical features are not yet supported by sycl";
7777
}
78-
n_nodes += model.trees[tree_idx]->GetNodes().size();
78+
n_nodes += model.trees[tree_idx]->Size();
7979
first_node_position_host[tree_idx - tree_begin + 1] = n_nodes;
8080
}
8181

8282
nodes.Resize(n_nodes);
8383
for (int tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
84-
auto& src_nodes = model.trees[tree_idx]->GetNodes();
84+
auto const& src_nodes = model.trees[tree_idx]->GetNodes(DeviceOrd::CPU());
8585
size_t n_nodes_shift = first_node_position_host[tree_idx - tree_begin];
8686
for (size_t node_idx = 0; node_idx < src_nodes.size(); node_idx++) {
8787
nodes.HostVector()[node_idx + n_nodes_shift] = static_cast<Node>(src_nodes[node_idx]);

python-package/xgboost/testing/ordinal.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -434,9 +434,11 @@ def run_thread_safety(DMatrixT: Type) -> bool:
434434
return True
435435

436436
futures = []
437+
n_cpus = os.cpu_count()
438+
assert n_cpus is not None
437439
for dm in (DMatrix, QuantileDMatrix):
438-
with ThreadPoolExecutor(max_workers=10) as e:
439-
for _ in range(10):
440+
with ThreadPoolExecutor(max_workers=max(n_cpus, 10)) as e:
441+
for _ in range(32):
440442
fut = e.submit(run_thread_safety, dm)
441443
futures.append(fut)
442444

src/predictor/cpu_predictor.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -640,7 +640,7 @@ class ColumnSplitHelper {
640640
tree_offsets_.resize(n_trees);
641641
for (decltype(tree_begin) i = 0; i < n_trees; i++) {
642642
auto const &tree = *model_.trees[tree_begin_ + i];
643-
tree_sizes_[i] = tree.GetNodes().size();
643+
tree_sizes_[i] = tree.Size();
644644
}
645645
// std::exclusive_scan (only available in c++17) equivalent to get tree offsets.
646646
tree_offsets_[0] = 0;

src/predictor/gpu_predictor.cu

Lines changed: 282 additions & 459 deletions
Large diffs are not rendered by default.

src/tree/hist/evaluate_splits.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -447,7 +447,7 @@ class HistEvaluator {
447447
tree[candidate.nid].SplitIndex(), left_weight, right_weight);
448448
evaluator = tree_evaluator_.GetEvaluator();
449449

450-
snode_.resize(tree.GetNodes().size());
450+
snode_.resize(tree.Size());
451451
snode_.at(left_child).stats = candidate.split.left_sum;
452452
snode_.at(left_child).root_gain =
453453
evaluator.CalcGain(candidate.nid, *param_, GradStats{candidate.split.left_sum});

src/tree/tree_view.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,8 @@ auto DispatchWeight(DeviceOrd device, RegTree const* tree) {
3131

3232
ScalarTreeView::ScalarTreeView(Context const* ctx, RegTree const* tree)
3333
: CategoriesMixIn{tree->GetCategoriesMatrix(ctx->Device())},
34-
nodes{tree->GetNodes().data()},
35-
stats{tree->GetStats().data()},
34+
nodes{tree->GetNodes(ctx->Device()).data()},
35+
stats{tree->GetStats(ctx->Device()).data()},
3636
n{tree->NumNodes()} {
3737
CHECK(!tree->IsMultiTarget());
3838
}

src/tree/tree_view.h

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -74,20 +74,22 @@ struct CategoriesMixIn {
7474
RegTree::CategoricalSplitMatrix cats;
7575

7676
[[nodiscard]] XGBOOST_DEVICE bool HasCategoricalSplit() const { return !cats.categories.empty(); }
77-
[[nodiscard]] XGBOOST_DEVICE RegTree::CategoricalSplitMatrix GetCategoriesMatrix() const {
77+
[[nodiscard]] XGBOOST_DEVICE RegTree::CategoricalSplitMatrix const& GetCategoriesMatrix() const {
7878
return cats;
7979
}
8080
/**
8181
* @brief Get the bit storage of categories used by a node.
8282
*/
83-
[[nodiscard]] common::Span<uint32_t const> NodeCats(bst_node_t nidx) const {
83+
[[nodiscard]] XGBOOST_DEVICE common::Span<uint32_t const> NodeCats(bst_node_t nidx) const {
8484
auto node_ptr = this->GetCategoriesMatrix().node_ptr;
8585
auto categories = this->GetCategoriesMatrix().categories;
8686
auto segment = node_ptr[nidx];
8787
auto node_cats = categories.subspan(segment.beg, segment.size);
8888
return node_cats;
8989
}
90-
[[nodiscard]] FeatureType SplitType(bst_node_t nidx) const { return cats.split_type[nidx]; }
90+
[[nodiscard]] XGBOOST_DEVICE FeatureType SplitType(bst_node_t nidx) const {
91+
return cats.split_type[nidx];
92+
}
9193
};
9294

9395
/**
@@ -142,20 +144,20 @@ struct ScalarTreeView : public WalkTreeMixIn<ScalarTreeView>, public CategoriesM
142144
}
143145

144146
[[nodiscard]] RTreeNodeStat const& Stat(bst_node_t nidx) const { return stats[nidx]; }
145-
[[nodiscard]] auto SumHess(bst_node_t nidx) const { return stats[nidx].sum_hess; }
146-
[[nodiscard]] auto LossChg(bst_node_t nidx) const { return stats[nidx].loss_chg; }
147+
[[nodiscard]] XGBOOST_DEVICE auto SumHess(bst_node_t nidx) const { return stats[nidx].sum_hess; }
148+
[[nodiscard]] XGBOOST_DEVICE auto LossChg(bst_node_t nidx) const { return stats[nidx].loss_chg; }
147149

148150
XGBOOST_DEVICE explicit ScalarTreeView(RegTree::Node const* nodes, RTreeNodeStat const* stats,
149151
RegTree::CategoricalSplitMatrix cats, bst_node_t n_nodes)
150152
: CategoriesMixIn{std::move(cats)}, nodes{nodes}, stats{stats}, n{n_nodes} {}
151153

152-
/** @brief Create a device view, not implemented yet. */
154+
/** @brief Create a device view */
153155
explicit ScalarTreeView(Context const* ctx, RegTree const* tree);
154156
/** @brief Create a host view */
155157
explicit ScalarTreeView(RegTree const* tree)
156158
: CategoriesMixIn{tree->GetCategoriesMatrix(DeviceOrd::CPU())},
157-
nodes{tree->GetNodes().data()},
158-
stats{tree->GetStats().data()},
159+
nodes{tree->GetNodes(DeviceOrd::CPU()).data()},
160+
stats{tree->GetStats(DeviceOrd::CPU()).data()},
159161
n{tree->NumNodes()} {
160162
CHECK(!tree->IsMultiTarget());
161163
}

src/tree/updater_gpu_hist.cu

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -669,7 +669,8 @@ struct GPUHistMakerDevice {
669669
// Use the nodes from tree, the leaf value might be changed by the objective since the
670670
// last update tree call.
671671
dh::CachingDeviceUVector<RegTree::Node> nodes;
672-
dh::CopyTo(p_tree->GetNodes(), &nodes, this->ctx_->CUDACtx()->Stream());
672+
// We can remove the CPU copy once we refactor the GPU hist to use the device tree.
673+
dh::CopyTo(p_tree->GetNodes(DeviceOrd::CPU()), &nodes, this->ctx_->CUDACtx()->Stream());
673674
common::Span<RegTree::Node> d_nodes = dh::ToSpan(nodes);
674675
CHECK_EQ(out_preds_d.Shape(1), 1);
675676
dh::LaunchN(d_position.size(), ctx_->CUDACtx()->Stream(),

tests/cpp/helpers.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -445,6 +445,10 @@ class CudaArrayIterForTest : public ArrayIterForTest {
445445
public:
446446
explicit CudaArrayIterForTest(float sparsity, size_t rows = Rows(), size_t cols = Cols(),
447447
size_t batches = Batches());
448+
explicit CudaArrayIterForTest(Context const* ctx, HostDeviceVector<float> const& data,
449+
std::size_t n_samples, bst_feature_t n_features,
450+
std::size_t n_batches)
451+
: ArrayIterForTest{ctx, data, n_samples, n_features, n_batches} {};
448452
int Next() override;
449453
~CudaArrayIterForTest() override = default;
450454
};

0 commit comments

Comments
 (0)