Skip to content

Commit 3868b5f

Browse files
authored
Share the model view for CPU and GPU. (#11759)
- Extract the device model view and share it with the CPU. - Ensure the model is locked when pulling tree data. - Use the host device vector to store the tree group.
1 parent b3ed14d commit 3868b5f

File tree

16 files changed

+350
-245
lines changed

16 files changed

+350
-245
lines changed

plugin/sycl/predictor/predictor.cc

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -90,10 +90,11 @@ class DeviceModel {
9090

9191
int num_group = model.learner_model_param->num_output_group;
9292
if (num_group > 1) {
93-
tree_group.Resize(model.tree_info.size());
93+
tree_group.Resize(model.tree_info.Size());
9494
auto& tree_group_host = tree_group.HostVector();
95-
for (size_t tree_idx = 0; tree_idx < model.tree_info.size(); tree_idx++)
96-
tree_group_host[tree_idx] = model.tree_info[tree_idx];
95+
auto const& tree_group_in = model.tree_info.ConstHostVector();
96+
for (size_t tree_idx = 0; tree_idx < tree_group_in.size(); tree_idx++)
97+
tree_group_host[tree_idx] = tree_group_in[tree_idx];
9798
}
9899
}
99100
};

src/common/device_vector.cuh

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -467,7 +467,9 @@ class DeviceUVectorImpl {
467467
this->size_ = n;
468468
this->capacity_ = n;
469469

470-
std::swap(this->data_, new_ptr);
470+
this->data_ = std::move(new_ptr);
471+
// swap failed with CTK12.8
472+
// std::swap(this->data_, new_ptr);
471473
}
472474
// Resize with init
473475
void resize(std::size_t n, T const &v) { // NOLINT

src/data/cat_container.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,10 @@
1818
#include "xgboost/base.h" // for bst_cat_t
1919
#include "xgboost/data.h" // for Entry
2020
#include "xgboost/host_device_vector.h" // for HostDeviceVector
21-
#include "xgboost/json.h" // for Json
2221

2322
namespace xgboost {
23+
class Json;
24+
2425
/**
2526
* @brief Error policy class used to interface with the encoder implementation.
2627
*/

src/gbm/gbtree.cc

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,21 @@ void GBTree::Configure(Args const& cfg) {
145145
}
146146
}
147147

148+
void GBTreeModel::InitTreesToUpdate() {
149+
if (trees_to_update.empty()) {
150+
for (auto& tree : trees) {
151+
trees_to_update.push_back(std::move(tree));
152+
}
153+
154+
trees.clear();
155+
param.num_trees = 0;
156+
tree_info.HostVector().clear();
157+
158+
iteration_indptr.clear();
159+
iteration_indptr.push_back(0);
160+
}
161+
}
162+
148163
void GPUCopyGradient(Context const*, linalg::Matrix<GradientPair> const*, bst_group_t,
149164
linalg::Matrix<GradientPair>*)
150165
#if defined(XGBOOST_USE_CUDA)
@@ -444,7 +459,9 @@ void GBTree::Slice(bst_layer_t begin, bst_layer_t end, bst_layer_t step, Gradien
444459

445460
auto& out_indptr = out_model.iteration_indptr;
446461
TreesOneGroup& out_trees = out_model.trees;
447-
std::vector<int32_t>& out_trees_info = out_model.tree_info;
462+
auto& out_tree_info = out_model.tree_info.HostVector();
463+
464+
auto const& in_tree_info = this->model_.tree_info.ConstHostVector();
448465

449466
bst_layer_t n_layers = (end - begin) / step;
450467
out_indptr.resize(n_layers + 1, 0);
@@ -462,8 +479,8 @@ void GBTree::Slice(bst_layer_t begin, bst_layer_t end, bst_layer_t step, Gradien
462479
std::unique_ptr<RegTree> new_tree{this->model_.trees.at(in_tree_idx)->Copy()};
463480
out_trees.emplace_back(std::move(new_tree));
464481

465-
bst_group_t group = this->model_.tree_info[in_tree_idx];
466-
out_trees_info.push_back(group);
482+
bst_group_t group = in_tree_info[in_tree_idx];
483+
out_tree_info.push_back(group);
467484

468485
out_model.iteration_indptr[out_l + 1]++;
469486
});
@@ -735,7 +752,7 @@ class Dart : public GBTree {
735752
auto layer_trees = [&]() {
736753
return model_.param.num_parallel_tree * model_.learner_model_param->OutputLength();
737754
};
738-
755+
auto const& h_tree_info = this->model_.tree_info.ConstHostVector();
739756
for (bst_tree_t i = tree_begin; i < tree_end; i += 1) {
740757
if (training && std::binary_search(idx_drop_.cbegin(), idx_drop_.cend(), i)) {
741758
continue;
@@ -749,7 +766,7 @@ class Dart : public GBTree {
749766

750767
// Multiply the output prediction by the weight.
751768
auto w = this->weight_drop_.at(i);
752-
auto group = model_.tree_info.at(i);
769+
auto group = h_tree_info.at(i);
753770
CHECK_EQ(p_out_preds->predictions.Size(), predts.predictions.Size());
754771

755772
size_t n_rows = p_fmat->Info().num_row_;
@@ -815,6 +832,7 @@ class Dart : public GBTree {
815832
CHECK(success) << msg;
816833
};
817834

835+
auto const& h_tree_info = this->model_.tree_info.ConstHostVector();
818836
// Inplace predict is not used for training, so no need to drop tree.
819837
for (bst_tree_t i = tree_begin; i < tree_end; ++i) {
820838
predict_impl(i);
@@ -837,7 +855,7 @@ class Dart : public GBTree {
837855
}
838856
// Multiply by the tree weight
839857
auto w = this->weight_drop_.at(i);
840-
auto group = model_.tree_info.at(i);
858+
auto group = h_tree_info.at(i);
841859
CHECK_EQ(predts.predictions.Size(), p_out_preds->predictions.Size());
842860

843861
size_t n_rows = p_fmat->Info().num_row_;

src/gbm/gbtree_model.cc

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ namespace xgboost::gbm {
1919
namespace {
2020
// For creating the tree indptr from old models.
2121
void MakeIndptr(GBTreeModel* out_model) {
22-
auto const& tree_info = out_model->tree_info;
22+
auto const& tree_info = out_model->tree_info.ConstHostVector();
2323
if (tree_info.empty()) {
2424
return;
2525
}
@@ -41,7 +41,7 @@ void MakeIndptr(GBTreeModel* out_model) {
4141
// Validate the consistency of the model.
4242
void Validate(GBTreeModel const& model) {
4343
CHECK_EQ(model.trees.size(), model.param.num_trees);
44-
CHECK_EQ(model.tree_info.size(), model.param.num_trees);
44+
CHECK_EQ(model.tree_info.Size(), model.param.num_trees);
4545
// True even if the model is empty since we should always have 0 as the first element.
4646
CHECK_EQ(model.iteration_indptr.back(), model.param.num_trees);
4747
}
@@ -61,9 +61,10 @@ void GBTreeModel::SaveModel(Json* p_out) const {
6161
trees_json[t] = std::move(jtree);
6262
});
6363

64-
std::vector<Json> tree_info_json(tree_info.size());
65-
for (size_t i = 0; i < tree_info.size(); ++i) {
66-
tree_info_json[i] = Integer(tree_info[i]);
64+
auto const& h_tree_info = tree_info.ConstHostVector();
65+
std::vector<Json> tree_info_json(tree_info.Size());
66+
for (size_t i = 0; i < h_tree_info.size(); ++i) {
67+
tree_info_json[i] = Integer(h_tree_info[i]);
6768
}
6869

6970
out["trees"] = Array(std::move(trees_json));
@@ -91,7 +92,8 @@ void GBTreeModel::LoadModel(Json const& in) {
9192

9293
auto const& tree_info_json = get<Array const>(jmodel.at("tree_info"));
9394
CHECK_EQ(tree_info_json.size(), param.num_trees);
94-
tree_info.resize(param.num_trees);
95+
auto& h_tree_info = this->tree_info.HostVector();
96+
h_tree_info.resize(param.num_trees);
9597

9698
common::ParallelFor(param.num_trees, ctx_->Threads(), [&](auto t) {
9799
auto tree_id = get<Integer const>(trees_json[t]["id"]);
@@ -100,7 +102,7 @@ void GBTreeModel::LoadModel(Json const& in) {
100102
});
101103

102104
for (bst_tree_t i = 0; i < param.num_trees; ++i) {
103-
tree_info[i] = get<Integer const>(tree_info_json[i]);
105+
h_tree_info[i] = get<Integer const>(tree_info_json[i]);
104106
}
105107

106108
auto indptr_it = jmodel.find("iteration_indptr");
@@ -144,10 +146,16 @@ bst_tree_t GBTreeModel::CommitModel(TreesOneIter&& new_trees) {
144146
}
145147

146148
void GBTreeModel::CommitModelGroup(TreesOneGroup&& new_trees, bst_target_t group_idx) {
149+
auto& h_tree_info = this->tree_info.HostVector();
147150
for (auto& new_tree : new_trees) {
148151
trees.push_back(std::move(new_tree));
149-
tree_info.push_back(group_idx);
152+
h_tree_info.push_back(group_idx);
150153
}
151154
param.num_trees += static_cast<int>(new_trees.size());
152155
}
156+
157+
common::Span<bst_target_t const> GBTreeModel::TreeGroups(DeviceOrd device) const {
158+
return device.IsCPU() ? this->tree_info.ConstHostSpan()
159+
: (this->tree_info.SetDevice(device), this->tree_info.ConstDeviceSpan());
160+
}
153161
} // namespace xgboost::gbm

src/gbm/gbtree_model.h

Lines changed: 18 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -6,45 +6,42 @@
66
#ifndef XGBOOST_GBM_GBTREE_MODEL_H_
77
#define XGBOOST_GBM_GBTREE_MODEL_H_
88

9-
#include <dmlc/io.h>
109
#include <dmlc/parameter.h>
11-
#include <xgboost/context.h>
12-
#include <xgboost/learner.h>
13-
#include <xgboost/model.h>
14-
#include <xgboost/parameter.h>
15-
#include <xgboost/tree_model.h>
1610

1711
#include <memory>
1812
#include <string>
19-
#include <utility>
2013
#include <vector>
2114

2215
#include "../common/threading_utils.h"
2316
#include "../data/cat_container.h" // for CatContainer
17+
#include "xgboost/context.h"
18+
#include "xgboost/learner.h"
19+
#include "xgboost/model.h"
20+
#include "xgboost/tree_model.h"
2421

2522
namespace xgboost {
2623

2724
class Json;
2825

2926
namespace gbm {
3027
/**
31-
* \brief Container for all trees built (not update) for one group.
28+
* @brief Container for all trees built (not update) for one group.
3229
*/
3330
using TreesOneGroup = std::vector<std::unique_ptr<RegTree>>;
3431
/**
35-
* \brief Container for all trees built (not update) for one iteration.
32+
* @brief Container for all trees built (not update) for one iteration.
3633
*/
3734
using TreesOneIter = std::vector<TreesOneGroup>;
3835

39-
/*! \brief model parameters */
36+
/** @brief GBTree model parameters. */
4037
struct GBTreeModelParam : public dmlc::Parameter<GBTreeModelParam> {
4138
public:
4239
/**
43-
* \brief number of trees
40+
* @brief The number of trees.
4441
*/
4542
std::int32_t num_trees{0};
4643
/**
47-
* \brief Number of trees for a forest.
44+
* @brief Number of trees for a single forest.
4845
*/
4946
std::int32_t num_parallel_tree{1};
5047

@@ -73,20 +70,8 @@ struct GBTreeModel : public Model {
7370
param.UpdateAllowUnknown(cfg);
7471
}
7572
}
76-
77-
void InitTreesToUpdate() {
78-
if (trees_to_update.empty()) {
79-
for (auto& tree : trees) {
80-
trees_to_update.push_back(std::move(tree));
81-
}
82-
trees.clear();
83-
param.num_trees = 0;
84-
tree_info.clear();
85-
86-
iteration_indptr.clear();
87-
iteration_indptr.push_back(0);
88-
}
89-
}
73+
/** @brief Move existing trees into the update queue. */
74+
void InitTreesToUpdate();
9075

9176
void SaveModel(Json* p_out) const override;
9277
void LoadModel(Json const& p_out) override;
@@ -114,9 +99,9 @@ struct GBTreeModel : public Model {
11499
return static_cast<std::int32_t>(iteration_indptr.size() - 1);
115100
}
116101

117-
// base margin
102+
/** @brief Global model properties. */
118103
LearnerModelParam const* learner_model_param;
119-
// model parameter
104+
/** @brief GBTree model parameters. */
120105
GBTreeModelParam param;
121106
/*! \brief vector of trees stored in the model */
122107
std::vector<std::unique_ptr<RegTree>> trees;
@@ -125,7 +110,7 @@ struct GBTreeModel : public Model {
125110
/**
126111
* @brief Group index for trees.
127112
*/
128-
std::vector<int> tree_info;
113+
HostDeviceVector<bst_target_t> tree_info;
129114
/**
130115
* @brief Number of trees accumulated for each iteration.
131116
*/
@@ -137,6 +122,10 @@ struct GBTreeModel : public Model {
137122
void Cats(std::shared_ptr<CatContainer> cats) { this->cats_ = cats; }
138123

139124
auto const* Ctx() const { return this->ctx_; }
125+
/**
126+
* @brief Getter for the tree group index.
127+
*/
128+
common::Span<bst_target_t const> TreeGroups(DeviceOrd device) const;
140129

141130
private:
142131
/**

0 commit comments

Comments
 (0)