11/* *
22 * Copyright 2014-2025, XGBoost Contributors
33 *
4- * \ brief model structure for tree
4+ * @ brief model structure for tree
55 * \author Tianqi Chen
66 */
77#ifndef XGBOOST_TREE_MODEL_H_
1010#include < xgboost/base.h>
1111#include < xgboost/data.h>
1212#include < xgboost/feature_map.h>
13- #include < xgboost/linalg.h> // for VectorView
13+ #include < xgboost/host_device_vector.h> // for HostDeviceVector
14+ #include < xgboost/linalg.h> // for VectorView
1415#include < xgboost/logging.h>
1516#include < xgboost/model.h>
1617#include < xgboost/multi_target_tree_model.h> // for MultiTargetTree
1718
1819#include < algorithm>
1920#include < cstring>
20- #include < limits>
21- #include < memory> // for make_unique
22- #include < stack>
21+ #include < limits> // for numeric_limits
22+ #include < memory> // for unique_ptr
2323#include < string>
24+ #include < type_traits> // for is_signed_v
2425#include < vector>
2526
2627namespace xgboost {
28+
29+ namespace tree {
30+ struct ScalarTreeView ;
31+ struct MultiTargetTreeView ;
32+ }
33+
2734class Json ;
2835
2936/* * @brief meta parameters of the tree */
@@ -46,67 +53,39 @@ struct TreeParam {
4653 void ToJson (Json* p_out) const ;
4754};
4855
49- /* ! \ brief node statistics used in regression tree */
56+ /* * @ brief node statistics used in regression tree */
5057struct RTreeNodeStat {
51- /* ! \ brief loss change caused by current split */
52- bst_float loss_chg;
53- /* ! \ brief sum of hessian values, used to measure coverage of data */
54- bst_float sum_hess;
55- /* ! \ brief weight of current node */
56- bst_float base_weight;
57- /* ! \ brief number of child that is leaf node known up to now */
58- int leaf_child_cnt {0 };
58+ /* * @ brief loss change caused by current split */
59+ float loss_chg;
60+ /* * @ brief sum of hessian values, used to measure coverage of data */
61+ float sum_hess;
62+ /* * @ brief weight of current node */
63+ float base_weight;
64+ /* * @ brief number of child that is leaf node known up to now */
65+ int leaf_child_cnt{0 };
5966
6067 RTreeNodeStat () = default ;
61- RTreeNodeStat (float loss_chg, float sum_hess, float weight) :
62- loss_chg{loss_chg}, sum_hess{sum_hess}, base_weight{weight} {}
68+ RTreeNodeStat (float loss_chg, float sum_hess, float weight)
69+ : loss_chg{loss_chg}, sum_hess{sum_hess}, base_weight{weight} {}
6370 bool operator ==(const RTreeNodeStat& b) const {
64- return loss_chg == b.loss_chg && sum_hess == b.sum_hess &&
65- base_weight == b. base_weight && leaf_child_cnt == b.leaf_child_cnt ;
71+ return loss_chg == b.loss_chg && sum_hess == b.sum_hess && base_weight == b. base_weight &&
72+ leaf_child_cnt == b.leaf_child_cnt ;
6673 }
6774};
6875
6976/* *
70- * \brief Helper for defining copyable data structure that contains unique pointers.
71- */
72- template <typename T>
73- class CopyUniquePtr {
74- std::unique_ptr<T> ptr_{nullptr };
75-
76- public:
77- CopyUniquePtr () = default ;
78- CopyUniquePtr (CopyUniquePtr const & that) {
79- ptr_.reset (nullptr );
80- if (that.ptr_ ) {
81- ptr_ = std::make_unique<T>(*that);
82- }
83- }
84- T* get () const noexcept { return ptr_.get (); } // NOLINT
85-
86- T& operator *() { return *ptr_; }
87- T* operator ->() noexcept { return this ->get (); }
88-
89- T const & operator *() const { return *ptr_; }
90- T const * operator ->() const noexcept { return this ->get (); }
91-
92- explicit operator bool () const { return static_cast <bool >(ptr_); }
93- bool operator !() const { return !ptr_; }
94- void reset (T* ptr) { ptr_.reset (ptr); } // NOLINT
95- };
96-
97- /* *
98- * \brief define regression tree to be the most common tree model.
77+ * @brief define regression tree to be the most common tree model.
9978 *
10079 * This is the data structure used in xgboost's major tree models.
10180 */
10281class RegTree : public Model {
10382 public:
104- using SplitCondT = bst_float ;
83+ using SplitCondT = float ;
10584 static constexpr bst_node_t kInvalidNodeId {MultiTargetTree::InvalidNodeId ()};
10685 static constexpr uint32_t kDeletedNodeMarker = std::numeric_limits<uint32_t >::max();
10786 static constexpr bst_node_t kRoot {0 };
10887
109- /* ! \ brief tree node */
88+ /* * @ brief tree node */
11089 class Node {
11190 public:
11291 XGBOOST_DEVICE Node () {
@@ -305,31 +284,6 @@ class RegTree : public Model {
305284 return nodes_ == b.nodes_ && stats_ == b.stats_ &&
306285 deleted_nodes_ == b.deleted_nodes_ && param_ == b.param_ ;
307286 }
308- /* \brief Iterate through all nodes in this tree.
309- *
310- * \param Function that accepts a node index, and returns false when iteration should
311- * stop, otherwise returns true.
312- */
313- template <typename Func> void WalkTree (Func func) const {
314- std::stack<bst_node_t > nodes;
315- nodes.push (kRoot );
316- auto &self = *this ;
317- while (!nodes.empty ()) {
318- auto nidx = nodes.top ();
319- nodes.pop ();
320- if (!func (nidx)) {
321- return ;
322- }
323- auto left = self.LeftChild (nidx);
324- auto right = self.RightChild (nidx);
325- if (left != RegTree::kInvalidNodeId ) {
326- nodes.push (left);
327- }
328- if (right != RegTree::kInvalidNodeId ) {
329- nodes.push (right);
330- }
331- }
332- }
333287 /* !
334288 * \brief Compares whether 2 trees are equal from a user's perspective. The equality
335289 * compares only non-deleted nodes.
@@ -432,21 +386,10 @@ class RegTree : public Model {
432386 [[nodiscard]] bst_node_t GetNumLeaves () const ;
433387 [[nodiscard]] bst_node_t GetNumSplitNodes () const ;
434388
435- /* !
436- * \brief get current depth
437- * \param nid node id
389+ /* *
390+ * @brief Get the depth of a node.
438391 */
439- [[nodiscard]] std::int32_t GetDepth (bst_node_t nid) const {
440- if (IsMultiTarget ()) {
441- return this ->p_mt_tree_ ->Depth (nid);
442- }
443- int depth = 0 ;
444- while (!nodes_[nid].IsRoot ()) {
445- ++depth;
446- nid = nodes_[nid].Parent ();
447- }
448- return depth;
449- }
392+ [[nodiscard]] bst_node_t GetDepth (bst_node_t nidx) const ;
450393 /* *
451394 * \brief Set the leaf weight for a multi-target tree.
452395 */
@@ -649,6 +592,10 @@ class RegTree : public Model {
649592 return this ->nodes_ .size ();
650593 }
651594
595+ [[nodiscard]] RegTree* Copy () const ;
596+ tree::ScalarTreeView HostScView () const ;
597+ tree::MultiTargetTreeView HostMtView () const ;
598+
652599 private:
653600 template <bool typed>
654601 void LoadCategoricalSplit (Json const & in);
@@ -668,7 +615,7 @@ class RegTree : public Model {
668615 // Ptr to split categories of each node.
669616 std::vector<CategoricalSplitMatrix::Segment> split_categories_segments_;
670617 // ptr to multi-target tree with vector leaf.
671- CopyUniquePtr <MultiTargetTree> p_mt_tree_;
618+ std::unique_ptr <MultiTargetTree> p_mt_tree_;
672619 // allocate a new node,
673620 // !!!!!! NOTE: may cause BUG here, nodes.resize
674621 bst_node_t AllocNode () {
0 commit comments