@@ -204,42 +204,47 @@ class RegTree : public Model {
204204 Info info_;
205205 };
206206
207- /* !
208- * \brief change a non leaf node to a leaf node, delete its children
209- * \param rid node id of the node
210- * \param value new leaf value
207+ /* *
208+ * @brief Change a non leaf node to a leaf node, delete its children
209+ *
210+ * @param nidx Node id
211+ * @param value The new leaf value
211212 */
212- void ChangeToLeaf (int rid, bst_float value) {
213- CHECK (nodes_[nodes_[rid].LeftChild () ].IsLeaf ());
214- CHECK (nodes_[nodes_[rid].RightChild ()].IsLeaf ());
215- this ->DeleteNode (nodes_[rid].LeftChild ());
216- this ->DeleteNode (nodes_[rid].RightChild ());
217- nodes_[rid].SetLeaf (value);
213+ void ChangeToLeaf (bst_node_t nidx, float value) {
214+ auto & h_nodes = nodes_.HostVector ();
215+ CHECK (h_nodes[h_nodes[nidx].LeftChild ()].IsLeaf ());
216+ CHECK (h_nodes[h_nodes[nidx].RightChild ()].IsLeaf ());
217+ this ->DeleteNode (h_nodes[nidx].LeftChild ());
218+ this ->DeleteNode (h_nodes[nidx].RightChild ());
219+ h_nodes[nidx].SetLeaf (value);
218220 }
219- /* !
220- * \brief collapse a non leaf node to a leaf node, delete its children
221- * \param rid node id of the node
222- * \param value new leaf value
221+ /* *
222+ * @brief Collapse a non leaf node to a leaf node, delete its children
223+ *
224+ * @param nidx Node id
225+ * @param value The new leaf value
223226 */
224- void CollapseToLeaf (int rid, bst_float value) {
225- if (nodes_[rid].IsLeaf ()) return ;
226- if (!nodes_[nodes_[rid].LeftChild () ].IsLeaf ()) {
227- CollapseToLeaf (nodes_[rid].LeftChild (), 0 .0f );
227+ void CollapseToLeaf (bst_node_t nidx, float value) {
228+ auto & h_nodes = nodes_.HostVector ();
229+ if (h_nodes[nidx].IsLeaf ()) return ;
230+ if (!h_nodes[h_nodes[nidx].LeftChild ()].IsLeaf ()) {
231+ CollapseToLeaf (h_nodes[nidx].LeftChild (), 0 .0f );
228232 }
229- if (!nodes_[nodes_[rid ].RightChild () ].IsLeaf ()) {
230- CollapseToLeaf (nodes_[rid ].RightChild (), 0 .0f );
233+ if (!h_nodes[h_nodes[nidx ].RightChild ()].IsLeaf ()) {
234+ CollapseToLeaf (h_nodes[nidx ].RightChild (), 0 .0f );
231235 }
232- this ->ChangeToLeaf (rid , value);
236+ this ->ChangeToLeaf (nidx , value);
233237 }
234238
235239 RegTree () {
236- nodes_.resize (param_.num_nodes );
237- stats_.resize (param_.num_nodes );
238- split_types_.resize (param_.num_nodes , FeatureType::kNumerical );
239- split_categories_segments_.resize (param_.num_nodes );
240+ nodes_.HostVector ().resize (param_.num_nodes );
241+ stats_.HostVector ().resize (param_.num_nodes );
242+ split_types_.HostVector ().resize (param_.num_nodes , FeatureType::kNumerical );
243+ split_categories_segments_.HostVector ().resize (param_.num_nodes );
244+ auto & h_nodes = nodes_.HostVector ();
240245 for (int i = 0 ; i < param_.num_nodes ; i++) {
241- nodes_ [i].SetLeaf (0 .0f );
242- nodes_ [i].SetParent (kInvalidNodeId );
246+ h_nodes [i].SetLeaf (0 .0f );
247+ h_nodes [i].SetParent (kInvalidNodeId );
243248 }
244249 }
245250 /* *
@@ -254,34 +259,28 @@ class RegTree : public Model {
254259 }
255260
256261 /* ! \brief get node given nid */
257- Node& operator [](int nid) {
258- return nodes_[nid];
259- }
260- /* ! \brief get node given nid */
261- const Node& operator [](int nid) const {
262- return nodes_[nid];
263- }
262+ Node& operator [](bst_node_t nidx) { return nodes_.HostVector ()[nidx]; }
264263
264+ public:
265265 /* ! \brief get const reference to nodes */
266- [[nodiscard]] const std::vector<Node>& GetNodes () const { return nodes_; }
266+ [[nodiscard]] const std::vector<Node>& GetNodes () const { return nodes_. ConstHostVector () ; }
267267
268268 /* ! \brief get const reference to stats */
269- [[nodiscard]] const std::vector<RTreeNodeStat>& GetStats () const { return stats_; }
269+ [[nodiscard]] const std::vector<RTreeNodeStat>& GetStats () const {
270+ return stats_.ConstHostVector ();
271+ }
270272
271273 /* ! \brief get node statistics given nid */
272274 RTreeNodeStat& Stat (int nid) {
273- return stats_[nid];
274- }
275- /* ! \brief get node statistics given nid */
276- [[nodiscard]] const RTreeNodeStat& Stat (int nid) const {
277- return stats_[nid];
275+ return stats_.HostVector ()[nid];
278276 }
279277
280278 void LoadModel (Json const & in) override ;
281279 void SaveModel (Json* out) const override ;
282280
283281 bool operator ==(const RegTree& b) const {
284- return nodes_ == b.nodes_ && stats_ == b.stats_ &&
282+ return nodes_.ConstHostVector () == b.nodes_ .ConstHostVector () &&
283+ stats_.ConstHostVector () == b.stats_ .ConstHostVector () &&
285284 deleted_nodes_ == b.deleted_nodes_ && param_ == b.param_ ;
286285 }
287286 /* !
@@ -344,9 +343,9 @@ class RegTree : public Model {
344343 bst_float right_leaf_weight, bst_float loss_change, float sum_hess,
345344 float left_sum, float right_sum);
346345 /* *
347- * \ brief Whether this tree has categorical split.
346+ * @ brief Whether this tree has categorical split.
348347 */
349- [[nodiscard]] bool HasCategoricalSplit () const { return !split_categories_.empty (); }
348+ [[nodiscard]] bool HasCategoricalSplit () const { return !split_categories_.Empty (); }
350349 /* *
351350 * \brief Whether this is a multi-target tree.
352351 */
@@ -391,26 +390,16 @@ class RegTree : public Model {
391390 */
392391 [[nodiscard]] bst_node_t GetDepth (bst_node_t nidx) const ;
393392 /* *
394- * \ brief Set the leaf weight for a multi-target tree.
393+ * @ brief Set the leaf weight for a multi-target tree.
395394 */
396395 void SetLeaf (bst_node_t nidx, linalg::VectorView<float const > weight) {
397396 CHECK (IsMultiTarget ());
398397 return this ->p_mt_tree_ ->SetLeaf (nidx, weight);
399398 }
400-
401- /* !
402- * \brief get maximum depth
403- * \param nid node id
404- */
405- [[nodiscard]] int MaxDepth (int nid) const {
406- if (nodes_[nid].IsLeaf ()) return 0 ;
407- return std::max (MaxDepth (nodes_[nid].LeftChild ()) + 1 , MaxDepth (nodes_[nid].RightChild ()) + 1 );
408- }
409-
410- /* !
411- * \brief get maximum depth
399+ /* *
400+ * @brief Get the maximum depth.
412401 */
413- int MaxDepth () { return MaxDepth (0 ); }
402+ [[nodiscard]] bst_node_t MaxDepth () const ;
414403
415404 /* !
416405 * \brief dense feature vector that can be taken by RegTree
@@ -474,35 +463,24 @@ class RegTree : public Model {
474463 */
475464 [[nodiscard]] std::string DumpModel (const FeatureMap& fmap, bool with_stats,
476465 std::string format) const ;
477- /* !
478- * \brief Get split type for a node.
479- * \param nidx Index of node.
480- * \return The type of this split. For leaf node it's always kNumerical.
481- */
482- [[nodiscard]] FeatureType NodeSplitType (bst_node_t nidx) const { return split_types_.at (nidx); }
483- /* !
484- * \brief Get split types for all nodes.
466+ /* *
467+ * @brief Get split types for all nodes.
485468 */
486- [[nodiscard]] std::vector<FeatureType> const & GetSplitTypes () const {
487- return split_types_;
469+ [[nodiscard]] common::Span<FeatureType const > GetSplitTypes (DeviceOrd device) const {
470+ return device.IsCPU () ? split_types_.ConstHostSpan ()
471+ : (split_types_.SetDevice (device), split_types_.ConstDeviceSpan ());
488472 }
489- [[nodiscard]] common::Span<uint32_t const > GetSplitCategories () const {
490- return split_categories_;
473+ [[nodiscard]] common::Span<uint32_t const > GetSplitCategories (DeviceOrd device) const {
474+ return device.IsCPU ()
475+ ? split_categories_.ConstHostSpan ()
476+ : (split_categories_.SetDevice (device), split_categories_.ConstDeviceSpan ());
491477 }
492- /* !
493- * \brief Get the bit storage for categories
494- */
495- [[nodiscard]] common::Span<uint32_t const > NodeCats (bst_node_t nidx) const {
496- auto node_ptr = GetCategoriesMatrix ().node_ptr ;
497- auto categories = GetCategoriesMatrix ().categories ;
498- auto segment = node_ptr[nidx];
499- auto node_cats = categories.subspan (segment.beg , segment.size );
500- return node_cats;
478+ [[nodiscard]] auto const & GetSplitCategoriesPtr () const {
479+ return split_categories_segments_.ConstHostVector ();
501480 }
502- [[nodiscard]] auto const & GetSplitCategoriesPtr () const { return split_categories_segments_; }
503481
504482 /* *
505- * \ brief CSR-like matrix for categorical splits.
483+ * @ brief CSR-like matrix for categorical splits.
506484 *
507485 * The fields of split_categories_segments_[i] are set such that the range
508486 * node_ptr[beg:(beg+size)] stores the bitset for the matching categories for the
@@ -518,78 +496,36 @@ class RegTree : public Model {
518496 common::Span<Segment const > node_ptr;
519497 };
520498
521- [[nodiscard]] CategoricalSplitMatrix GetCategoriesMatrix () const {
499+ [[nodiscard]] CategoricalSplitMatrix GetCategoriesMatrix (DeviceOrd device ) const {
522500 CategoricalSplitMatrix view;
523- view.split_type = common::Span<FeatureType const >(this ->GetSplitTypes ());
524- view.categories = this ->GetSplitCategories ();
525- view.node_ptr = common::Span<CategoricalSplitMatrix::Segment const >(split_categories_segments_);
501+ view.split_type = this ->GetSplitTypes (device);
502+ view.categories = this ->GetSplitCategories (device);
503+ if (device.IsCPU ()) {
504+ view.node_ptr = split_categories_segments_.ConstHostSpan ();
505+ } else {
506+ split_categories_segments_.SetDevice (device);
507+ view.node_ptr = split_categories_segments_.ConstDeviceSpan ();
508+ }
526509 return view;
527510 }
528511
529- [[nodiscard]] bst_feature_t SplitIndex (bst_node_t nidx) const {
530- if (IsMultiTarget ()) {
531- return this ->p_mt_tree_ ->SplitIndex (nidx);
532- }
533- return (*this )[nidx].SplitIndex ();
534- }
535- [[nodiscard]] float SplitCond (bst_node_t nidx) const {
536- if (IsMultiTarget ()) {
537- return this ->p_mt_tree_ ->SplitCond (nidx);
538- }
539- return (*this )[nidx].SplitCond ();
540- }
541- [[nodiscard]] bool DefaultLeft (bst_node_t nidx) const {
542- if (IsMultiTarget ()) {
543- return this ->p_mt_tree_ ->DefaultLeft (nidx);
544- }
545- return (*this )[nidx].DefaultLeft ();
546- }
547- [[nodiscard]] bst_node_t DefaultChild (bst_node_t nidx) const {
548- return this ->DefaultLeft (nidx) ? this ->LeftChild (nidx) : this ->RightChild (nidx);
549- }
550- [[nodiscard]] bool IsRoot (bst_node_t nidx) const {
551- if (IsMultiTarget ()) {
552- return nidx == kRoot ;
553- }
554- return (*this )[nidx].IsRoot ();
555- }
556- [[nodiscard]] bool IsLeaf (bst_node_t nidx) const {
557- if (IsMultiTarget ()) {
558- return this ->p_mt_tree_ ->IsLeaf (nidx);
559- }
560- return (*this )[nidx].IsLeaf ();
561- }
562- [[nodiscard]] bst_node_t Parent (bst_node_t nidx) const {
563- if (IsMultiTarget ()) {
564- return this ->p_mt_tree_ ->Parent (nidx);
565- }
566- return (*this )[nidx].Parent ();
567- }
568512 [[nodiscard]] bst_node_t LeftChild (bst_node_t nidx) const {
569513 if (IsMultiTarget ()) {
570514 return this ->p_mt_tree_ ->LeftChild (nidx);
571515 }
572- return (* this )[nidx].LeftChild ();
516+ return nodes_. ConstHostVector ( )[nidx].LeftChild ();
573517 }
574518 [[nodiscard]] bst_node_t RightChild (bst_node_t nidx) const {
575519 if (IsMultiTarget ()) {
576520 return this ->p_mt_tree_ ->RightChild (nidx);
577521 }
578- return (*this )[nidx].RightChild ();
579- }
580- [[nodiscard]] bool IsLeftChild (bst_node_t nidx) const {
581- if (IsMultiTarget ()) {
582- CHECK_NE (nidx, kRoot );
583- auto p = this ->p_mt_tree_ ->Parent (nidx);
584- return nidx == this ->p_mt_tree_ ->LeftChild (p);
585- }
586- return (*this )[nidx].IsLeftChild ();
522+ return nodes_.ConstHostVector ()[nidx].RightChild ();
587523 }
588524 [[nodiscard]] bst_node_t Size () const {
589525 if (IsMultiTarget ()) {
590526 return this ->p_mt_tree_ ->Size ();
591527 }
592- return this ->nodes_ .size ();
528+ return this ->nodes_ .Size ();
593529 }
594530
595531 [[nodiscard]] RegTree* Copy () const ;
@@ -603,17 +539,17 @@ class RegTree : public Model {
603539 /* ! \brief model parameter */
604540 TreeParam param_;
605541 // vector of nodes
606- std::vector <Node> nodes_;
542+ HostDeviceVector <Node> nodes_;
607543 // free node space, used during training process
608544 std::vector<int > deleted_nodes_;
609545 // stats of nodes
610- std::vector <RTreeNodeStat> stats_;
611- std::vector <FeatureType> split_types_;
546+ HostDeviceVector <RTreeNodeStat> stats_;
547+ HostDeviceVector <FeatureType> split_types_;
612548
613549 // Categories for each internal node.
614- std::vector <uint32_t > split_categories_;
550+ HostDeviceVector <uint32_t > split_categories_;
615551 // Ptr to split categories of each node.
616- std::vector <CategoricalSplitMatrix::Segment> split_categories_segments_;
552+ HostDeviceVector <CategoricalSplitMatrix::Segment> split_categories_segments_;
617553 // ptr to multi-target tree with vector leaf.
618554 std::unique_ptr<MultiTargetTree> p_mt_tree_;
619555 // allocate a new node,
@@ -622,17 +558,17 @@ class RegTree : public Model {
622558 if (param_.num_deleted != 0 ) {
623559 int nid = deleted_nodes_.back ();
624560 deleted_nodes_.pop_back ();
625- nodes_[nid].Reuse ();
561+ nodes_. HostVector () [nid].Reuse ();
626562 --param_.num_deleted ;
627563 return nid;
628564 }
629565 int nd = param_.num_nodes ++;
630566 CHECK_LT (param_.num_nodes , std::numeric_limits<int >::max ())
631567 << " number of nodes in the tree exceed 2^31" ;
632- nodes_.resize (param_.num_nodes );
633- stats_.resize (param_.num_nodes );
634- split_types_.resize (param_.num_nodes , FeatureType::kNumerical );
635- split_categories_segments_.resize (param_.num_nodes );
568+ nodes_.HostVector (). resize (param_.num_nodes );
569+ stats_.HostVector (). resize (param_.num_nodes );
570+ split_types_.HostVector (). resize (param_.num_nodes , FeatureType::kNumerical );
571+ split_categories_segments_.HostVector (). resize (param_.num_nodes );
636572 return nd;
637573 }
638574 // delete a tree node, keep the parent field to allow trace back
@@ -646,7 +582,7 @@ class RegTree : public Model {
646582 }
647583
648584 deleted_nodes_.push_back (nid);
649- nodes_[nid].MarkDelete ();
585+ nodes_. HostVector () [nid].MarkDelete ();
650586 ++param_.num_deleted ;
651587 }
652588};
0 commit comments