Skip to content

Commit 70b9b10

Browse files
authored
Provide device storage for scalar tree. (#11750)
- Cleanup various accessors. - Use the view object for all const access.
1 parent 8b55aee commit 70b9b10

33 files changed

+655
-633
lines changed

include/xgboost/multi_target_tree_model.h

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -74,37 +74,16 @@ class MultiTargetTree : public Model {
7474
[[nodiscard]] bool IsLeaf(bst_node_t nidx) const {
7575
return left_.ConstHostVector()[nidx] == InvalidNodeId();
7676
}
77-
[[nodiscard]] bst_node_t Parent(bst_node_t nidx) const {
78-
return parent_.ConstHostVector().at(nidx);
79-
}
8077
[[nodiscard]] bst_node_t LeftChild(bst_node_t nidx) const {
8178
return left_.ConstHostVector().at(nidx);
8279
}
8380
[[nodiscard]] bst_node_t RightChild(bst_node_t nidx) const {
8481
return right_.ConstHostVector().at(nidx);
8582
}
8683

87-
[[nodiscard]] bst_feature_t SplitIndex(bst_node_t nidx) const {
88-
return split_index_.ConstHostVector()[nidx];
89-
}
90-
[[nodiscard]] float SplitCond(bst_node_t nidx) const {
91-
return split_conds_.ConstHostVector()[nidx];
92-
}
93-
[[nodiscard]] bool DefaultLeft(bst_node_t nidx) const {
94-
return default_left_.ConstHostVector()[nidx];
95-
}
96-
[[nodiscard]] bst_node_t DefaultChild(bst_node_t nidx) const {
97-
return this->DefaultLeft(nidx) ? this->LeftChild(nidx) : this->RightChild(nidx);
98-
}
99-
10084
[[nodiscard]] bst_target_t NumTargets() const;
101-
10285
[[nodiscard]] auto NumLeaves() const { return this->weights_.Size() / this->NumTargets(); }
10386

104-
[[nodiscard]] bool IsLeftChild(bst_node_t nidx) const {
105-
auto p = this->Parent(nidx);
106-
return nidx == this->LeftChild(p);
107-
}
10887
[[nodiscard]] std::size_t Size() const;
10988
[[nodiscard]] MultiTargetTree* Copy(TreeParam const* param) const;
11089

include/xgboost/tree_model.h

Lines changed: 82 additions & 146 deletions
Original file line numberDiff line numberDiff line change
@@ -204,42 +204,47 @@ class RegTree : public Model {
204204
Info info_;
205205
};
206206

207-
/*!
208-
* \brief change a non leaf node to a leaf node, delete its children
209-
* \param rid node id of the node
210-
* \param value new leaf value
207+
/**
208+
* @brief Change a non leaf node to a leaf node, delete its children
209+
*
210+
* @param nidx Node id
211+
* @param value The new leaf value
211212
*/
212-
void ChangeToLeaf(int rid, bst_float value) {
213-
CHECK(nodes_[nodes_[rid].LeftChild() ].IsLeaf());
214-
CHECK(nodes_[nodes_[rid].RightChild()].IsLeaf());
215-
this->DeleteNode(nodes_[rid].LeftChild());
216-
this->DeleteNode(nodes_[rid].RightChild());
217-
nodes_[rid].SetLeaf(value);
213+
void ChangeToLeaf(bst_node_t nidx, float value) {
214+
auto& h_nodes = nodes_.HostVector();
215+
CHECK(h_nodes[h_nodes[nidx].LeftChild()].IsLeaf());
216+
CHECK(h_nodes[h_nodes[nidx].RightChild()].IsLeaf());
217+
this->DeleteNode(h_nodes[nidx].LeftChild());
218+
this->DeleteNode(h_nodes[nidx].RightChild());
219+
h_nodes[nidx].SetLeaf(value);
218220
}
219-
/*!
220-
* \brief collapse a non leaf node to a leaf node, delete its children
221-
* \param rid node id of the node
222-
* \param value new leaf value
221+
/**
222+
* @brief Collapse a non leaf node to a leaf node, delete its children
223+
*
224+
* @param nidx Node id
225+
* @param value The new leaf value
223226
*/
224-
void CollapseToLeaf(int rid, bst_float value) {
225-
if (nodes_[rid].IsLeaf()) return;
226-
if (!nodes_[nodes_[rid].LeftChild() ].IsLeaf()) {
227-
CollapseToLeaf(nodes_[rid].LeftChild(), 0.0f);
227+
void CollapseToLeaf(bst_node_t nidx, float value) {
228+
auto& h_nodes = nodes_.HostVector();
229+
if (h_nodes[nidx].IsLeaf()) return;
230+
if (!h_nodes[h_nodes[nidx].LeftChild()].IsLeaf()) {
231+
CollapseToLeaf(h_nodes[nidx].LeftChild(), 0.0f);
228232
}
229-
if (!nodes_[nodes_[rid].RightChild() ].IsLeaf()) {
230-
CollapseToLeaf(nodes_[rid].RightChild(), 0.0f);
233+
if (!h_nodes[h_nodes[nidx].RightChild()].IsLeaf()) {
234+
CollapseToLeaf(h_nodes[nidx].RightChild(), 0.0f);
231235
}
232-
this->ChangeToLeaf(rid, value);
236+
this->ChangeToLeaf(nidx, value);
233237
}
234238

235239
RegTree() {
236-
nodes_.resize(param_.num_nodes);
237-
stats_.resize(param_.num_nodes);
238-
split_types_.resize(param_.num_nodes, FeatureType::kNumerical);
239-
split_categories_segments_.resize(param_.num_nodes);
240+
nodes_.HostVector().resize(param_.num_nodes);
241+
stats_.HostVector().resize(param_.num_nodes);
242+
split_types_.HostVector().resize(param_.num_nodes, FeatureType::kNumerical);
243+
split_categories_segments_.HostVector().resize(param_.num_nodes);
244+
auto& h_nodes = nodes_.HostVector();
240245
for (int i = 0; i < param_.num_nodes; i++) {
241-
nodes_[i].SetLeaf(0.0f);
242-
nodes_[i].SetParent(kInvalidNodeId);
246+
h_nodes[i].SetLeaf(0.0f);
247+
h_nodes[i].SetParent(kInvalidNodeId);
243248
}
244249
}
245250
/**
@@ -254,34 +259,28 @@ class RegTree : public Model {
254259
}
255260

256261
/*! \brief get node given nid */
257-
Node& operator[](int nid) {
258-
return nodes_[nid];
259-
}
260-
/*! \brief get node given nid */
261-
const Node& operator[](int nid) const {
262-
return nodes_[nid];
263-
}
262+
Node& operator[](bst_node_t nidx) { return nodes_.HostVector()[nidx]; }
264263

264+
public:
265265
/*! \brief get const reference to nodes */
266-
[[nodiscard]] const std::vector<Node>& GetNodes() const { return nodes_; }
266+
[[nodiscard]] const std::vector<Node>& GetNodes() const { return nodes_.ConstHostVector(); }
267267

268268
/*! \brief get const reference to stats */
269-
[[nodiscard]] const std::vector<RTreeNodeStat>& GetStats() const { return stats_; }
269+
[[nodiscard]] const std::vector<RTreeNodeStat>& GetStats() const {
270+
return stats_.ConstHostVector();
271+
}
270272

271273
/*! \brief get node statistics given nid */
272274
RTreeNodeStat& Stat(int nid) {
273-
return stats_[nid];
274-
}
275-
/*! \brief get node statistics given nid */
276-
[[nodiscard]] const RTreeNodeStat& Stat(int nid) const {
277-
return stats_[nid];
275+
return stats_.HostVector()[nid];
278276
}
279277

280278
void LoadModel(Json const& in) override;
281279
void SaveModel(Json* out) const override;
282280

283281
bool operator==(const RegTree& b) const {
284-
return nodes_ == b.nodes_ && stats_ == b.stats_ &&
282+
return nodes_.ConstHostVector() == b.nodes_.ConstHostVector() &&
283+
stats_.ConstHostVector() == b.stats_.ConstHostVector() &&
285284
deleted_nodes_ == b.deleted_nodes_ && param_ == b.param_;
286285
}
287286
/*!
@@ -344,9 +343,9 @@ class RegTree : public Model {
344343
bst_float right_leaf_weight, bst_float loss_change, float sum_hess,
345344
float left_sum, float right_sum);
346345
/**
347-
* \brief Whether this tree has categorical split.
346+
* @brief Whether this tree has categorical split.
348347
*/
349-
[[nodiscard]] bool HasCategoricalSplit() const { return !split_categories_.empty(); }
348+
[[nodiscard]] bool HasCategoricalSplit() const { return !split_categories_.Empty(); }
350349
/**
351350
* \brief Whether this is a multi-target tree.
352351
*/
@@ -391,26 +390,16 @@ class RegTree : public Model {
391390
*/
392391
[[nodiscard]] bst_node_t GetDepth(bst_node_t nidx) const;
393392
/**
394-
* \brief Set the leaf weight for a multi-target tree.
393+
* @brief Set the leaf weight for a multi-target tree.
395394
*/
396395
void SetLeaf(bst_node_t nidx, linalg::VectorView<float const> weight) {
397396
CHECK(IsMultiTarget());
398397
return this->p_mt_tree_->SetLeaf(nidx, weight);
399398
}
400-
401-
/*!
402-
* \brief get maximum depth
403-
* \param nid node id
404-
*/
405-
[[nodiscard]] int MaxDepth(int nid) const {
406-
if (nodes_[nid].IsLeaf()) return 0;
407-
return std::max(MaxDepth(nodes_[nid].LeftChild()) + 1, MaxDepth(nodes_[nid].RightChild()) + 1);
408-
}
409-
410-
/*!
411-
* \brief get maximum depth
399+
/**
400+
* @brief Get the maximum depth.
412401
*/
413-
int MaxDepth() { return MaxDepth(0); }
402+
[[nodiscard]] bst_node_t MaxDepth() const;
414403

415404
/*!
416405
* \brief dense feature vector that can be taken by RegTree
@@ -474,35 +463,24 @@ class RegTree : public Model {
474463
*/
475464
[[nodiscard]] std::string DumpModel(const FeatureMap& fmap, bool with_stats,
476465
std::string format) const;
477-
/*!
478-
* \brief Get split type for a node.
479-
* \param nidx Index of node.
480-
* \return The type of this split. For leaf node it's always kNumerical.
481-
*/
482-
[[nodiscard]] FeatureType NodeSplitType(bst_node_t nidx) const { return split_types_.at(nidx); }
483-
/*!
484-
* \brief Get split types for all nodes.
466+
/**
467+
* @brief Get split types for all nodes.
485468
*/
486-
[[nodiscard]] std::vector<FeatureType> const& GetSplitTypes() const {
487-
return split_types_;
469+
[[nodiscard]] common::Span<FeatureType const> GetSplitTypes(DeviceOrd device) const {
470+
return device.IsCPU() ? split_types_.ConstHostSpan()
471+
: (split_types_.SetDevice(device), split_types_.ConstDeviceSpan());
488472
}
489-
[[nodiscard]] common::Span<uint32_t const> GetSplitCategories() const {
490-
return split_categories_;
473+
[[nodiscard]] common::Span<uint32_t const> GetSplitCategories(DeviceOrd device) const {
474+
return device.IsCPU()
475+
? split_categories_.ConstHostSpan()
476+
: (split_categories_.SetDevice(device), split_categories_.ConstDeviceSpan());
491477
}
492-
/*!
493-
* \brief Get the bit storage for categories
494-
*/
495-
[[nodiscard]] common::Span<uint32_t const> NodeCats(bst_node_t nidx) const {
496-
auto node_ptr = GetCategoriesMatrix().node_ptr;
497-
auto categories = GetCategoriesMatrix().categories;
498-
auto segment = node_ptr[nidx];
499-
auto node_cats = categories.subspan(segment.beg, segment.size);
500-
return node_cats;
478+
[[nodiscard]] auto const& GetSplitCategoriesPtr() const {
479+
return split_categories_segments_.ConstHostVector();
501480
}
502-
[[nodiscard]] auto const& GetSplitCategoriesPtr() const { return split_categories_segments_; }
503481

504482
/**
505-
* \brief CSR-like matrix for categorical splits.
483+
* @brief CSR-like matrix for categorical splits.
506484
*
507485
* The fields of split_categories_segments_[i] are set such that the range
508486
* node_ptr[beg:(beg+size)] stores the bitset for the matching categories for the
@@ -518,78 +496,36 @@ class RegTree : public Model {
518496
common::Span<Segment const> node_ptr;
519497
};
520498

521-
[[nodiscard]] CategoricalSplitMatrix GetCategoriesMatrix() const {
499+
[[nodiscard]] CategoricalSplitMatrix GetCategoriesMatrix(DeviceOrd device) const {
522500
CategoricalSplitMatrix view;
523-
view.split_type = common::Span<FeatureType const>(this->GetSplitTypes());
524-
view.categories = this->GetSplitCategories();
525-
view.node_ptr = common::Span<CategoricalSplitMatrix::Segment const>(split_categories_segments_);
501+
view.split_type = this->GetSplitTypes(device);
502+
view.categories = this->GetSplitCategories(device);
503+
if (device.IsCPU()) {
504+
view.node_ptr = split_categories_segments_.ConstHostSpan();
505+
} else {
506+
split_categories_segments_.SetDevice(device);
507+
view.node_ptr = split_categories_segments_.ConstDeviceSpan();
508+
}
526509
return view;
527510
}
528511

529-
[[nodiscard]] bst_feature_t SplitIndex(bst_node_t nidx) const {
530-
if (IsMultiTarget()) {
531-
return this->p_mt_tree_->SplitIndex(nidx);
532-
}
533-
return (*this)[nidx].SplitIndex();
534-
}
535-
[[nodiscard]] float SplitCond(bst_node_t nidx) const {
536-
if (IsMultiTarget()) {
537-
return this->p_mt_tree_->SplitCond(nidx);
538-
}
539-
return (*this)[nidx].SplitCond();
540-
}
541-
[[nodiscard]] bool DefaultLeft(bst_node_t nidx) const {
542-
if (IsMultiTarget()) {
543-
return this->p_mt_tree_->DefaultLeft(nidx);
544-
}
545-
return (*this)[nidx].DefaultLeft();
546-
}
547-
[[nodiscard]] bst_node_t DefaultChild(bst_node_t nidx) const {
548-
return this->DefaultLeft(nidx) ? this->LeftChild(nidx) : this->RightChild(nidx);
549-
}
550-
[[nodiscard]] bool IsRoot(bst_node_t nidx) const {
551-
if (IsMultiTarget()) {
552-
return nidx == kRoot;
553-
}
554-
return (*this)[nidx].IsRoot();
555-
}
556-
[[nodiscard]] bool IsLeaf(bst_node_t nidx) const {
557-
if (IsMultiTarget()) {
558-
return this->p_mt_tree_->IsLeaf(nidx);
559-
}
560-
return (*this)[nidx].IsLeaf();
561-
}
562-
[[nodiscard]] bst_node_t Parent(bst_node_t nidx) const {
563-
if (IsMultiTarget()) {
564-
return this->p_mt_tree_->Parent(nidx);
565-
}
566-
return (*this)[nidx].Parent();
567-
}
568512
[[nodiscard]] bst_node_t LeftChild(bst_node_t nidx) const {
569513
if (IsMultiTarget()) {
570514
return this->p_mt_tree_->LeftChild(nidx);
571515
}
572-
return (*this)[nidx].LeftChild();
516+
return nodes_.ConstHostVector()[nidx].LeftChild();
573517
}
574518
[[nodiscard]] bst_node_t RightChild(bst_node_t nidx) const {
575519
if (IsMultiTarget()) {
576520
return this->p_mt_tree_->RightChild(nidx);
577521
}
578-
return (*this)[nidx].RightChild();
579-
}
580-
[[nodiscard]] bool IsLeftChild(bst_node_t nidx) const {
581-
if (IsMultiTarget()) {
582-
CHECK_NE(nidx, kRoot);
583-
auto p = this->p_mt_tree_->Parent(nidx);
584-
return nidx == this->p_mt_tree_->LeftChild(p);
585-
}
586-
return (*this)[nidx].IsLeftChild();
522+
return nodes_.ConstHostVector()[nidx].RightChild();
587523
}
588524
[[nodiscard]] bst_node_t Size() const {
589525
if (IsMultiTarget()) {
590526
return this->p_mt_tree_->Size();
591527
}
592-
return this->nodes_.size();
528+
return this->nodes_.Size();
593529
}
594530

595531
[[nodiscard]] RegTree* Copy() const;
@@ -603,17 +539,17 @@ class RegTree : public Model {
603539
/*! \brief model parameter */
604540
TreeParam param_;
605541
// vector of nodes
606-
std::vector<Node> nodes_;
542+
HostDeviceVector<Node> nodes_;
607543
// free node space, used during training process
608544
std::vector<int> deleted_nodes_;
609545
// stats of nodes
610-
std::vector<RTreeNodeStat> stats_;
611-
std::vector<FeatureType> split_types_;
546+
HostDeviceVector<RTreeNodeStat> stats_;
547+
HostDeviceVector<FeatureType> split_types_;
612548

613549
// Categories for each internal node.
614-
std::vector<uint32_t> split_categories_;
550+
HostDeviceVector<uint32_t> split_categories_;
615551
// Ptr to split categories of each node.
616-
std::vector<CategoricalSplitMatrix::Segment> split_categories_segments_;
552+
HostDeviceVector<CategoricalSplitMatrix::Segment> split_categories_segments_;
617553
// ptr to multi-target tree with vector leaf.
618554
std::unique_ptr<MultiTargetTree> p_mt_tree_;
619555
// allocate a new node,
@@ -622,17 +558,17 @@ class RegTree : public Model {
622558
if (param_.num_deleted != 0) {
623559
int nid = deleted_nodes_.back();
624560
deleted_nodes_.pop_back();
625-
nodes_[nid].Reuse();
561+
nodes_.HostVector()[nid].Reuse();
626562
--param_.num_deleted;
627563
return nid;
628564
}
629565
int nd = param_.num_nodes++;
630566
CHECK_LT(param_.num_nodes, std::numeric_limits<int>::max())
631567
<< "number of nodes in the tree exceed 2^31";
632-
nodes_.resize(param_.num_nodes);
633-
stats_.resize(param_.num_nodes);
634-
split_types_.resize(param_.num_nodes, FeatureType::kNumerical);
635-
split_categories_segments_.resize(param_.num_nodes);
568+
nodes_.HostVector().resize(param_.num_nodes);
569+
stats_.HostVector().resize(param_.num_nodes);
570+
split_types_.HostVector().resize(param_.num_nodes, FeatureType::kNumerical);
571+
split_categories_segments_.HostVector().resize(param_.num_nodes);
636572
return nd;
637573
}
638574
// delete a tree node, keep the parent field to allow trace back
@@ -646,7 +582,7 @@ class RegTree : public Model {
646582
}
647583

648584
deleted_nodes_.push_back(nid);
649-
nodes_[nid].MarkDelete();
585+
nodes_.HostVector()[nid].MarkDelete();
650586
++param_.num_deleted;
651587
}
652588
};

0 commit comments

Comments
 (0)