Skip to content

Commit b58f701

Browse files
authored
Implement a tree view for the scalar tree. (#11741)
- Add scalar tree view. - Replace the GPU-specific tree view. - Replace the walk tree routine. - Replace the multi-target-specific prediction function.
1 parent 79ced22 commit b58f701

26 files changed

+612
-448
lines changed

R-package/src/Makevars.in

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ OBJECTS= \
9090
$(PKGROOT)/src/tree/param.o \
9191
$(PKGROOT)/src/tree/fit_stump.o \
9292
$(PKGROOT)/src/tree/tree_model.o \
93+
$(PKGROOT)/src/tree/tree_view.o \
9394
$(PKGROOT)/src/tree/tree_updater.o \
9495
$(PKGROOT)/src/tree/multi_target_tree_model.o \
9596
$(PKGROOT)/src/tree/updater_approx.o \

R-package/src/Makevars.win.in

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ OBJECTS= \
8989
$(PKGROOT)/src/tree/param.o \
9090
$(PKGROOT)/src/tree/fit_stump.o \
9191
$(PKGROOT)/src/tree/tree_model.o \
92+
$(PKGROOT)/src/tree/tree_view.o \
9293
$(PKGROOT)/src/tree/multi_target_tree_model.o \
9394
$(PKGROOT)/src/tree/tree_updater.o \
9495
$(PKGROOT)/src/tree/updater_approx.o \

include/xgboost/base.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ using bst_bin_t = std::int32_t; // NOLINT
114114
*/
115115
using bst_idx_t = std::uint64_t; // NOLINT
116116
/**
117-
* \brief Type for tree node index.
117+
* \brief Type for tree node index and tree depth.
118118
*/
119119
using bst_node_t = std::int32_t; // NOLINT
120120
/**

include/xgboost/multi_target_tree_model.h

Lines changed: 17 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -15,60 +15,21 @@
1515

1616
#include <cstddef> // for size_t
1717
#include <cstdint> // for uint8_t
18-
#include <mutex> // for mutex
1918
#include <vector> // for vector
2019

2120
namespace xgboost {
21+
namespace tree {
22+
struct MultiTargetTreeView;
23+
}
2224
struct TreeParam;
23-
/**
24-
* @brief A view to the @MultiTargetTree suitable for both host and device.
25-
*/
26-
struct MultiTargetTreeView {
27-
static bst_node_t constexpr InvalidNodeId() { return -1; }
28-
29-
bst_node_t const* left;
30-
bst_node_t const* right;
31-
bst_node_t const* parent;
32-
33-
bst_feature_t const* split_index;
34-
std::uint8_t const* default_left;
35-
float const* split_conds;
36-
37-
// The number of nodes
38-
std::size_t n{0};
39-
40-
linalg::MatrixView<float const> weights;
41-
42-
[[nodiscard]] XGBOOST_DEVICE bool IsLeaf(bst_node_t nidx) const {
43-
return left[nidx] == InvalidNodeId();
44-
}
45-
46-
[[nodiscard]] XGBOOST_DEVICE bst_node_t LeftChild(bst_node_t nidx) const { return left[nidx]; }
47-
[[nodiscard]] XGBOOST_DEVICE bst_node_t RightChild(bst_node_t nidx) const { return right[nidx]; }
48-
[[nodiscard]] XGBOOST_DEVICE bst_feature_t SplitIndex(bst_node_t nidx) const {
49-
return split_index[nidx];
50-
}
51-
[[nodiscard]] XGBOOST_DEVICE float SplitCond(bst_node_t nidx) const { return split_conds[nidx]; }
52-
[[nodiscard]] XGBOOST_DEVICE bool DefaultLeft(bst_node_t nidx) const {
53-
return default_left[nidx];
54-
}
55-
[[nodiscard]] XGBOOST_DEVICE bst_node_t DefaultChild(bst_node_t nidx) const {
56-
return this->DefaultLeft(nidx) ? this->LeftChild(nidx) : this->RightChild(nidx);
57-
}
58-
[[nodiscard]] XGBOOST_DEVICE linalg::VectorView<float const> LeafValue(bst_node_t nidx) const {
59-
return this->weights.Slice(nidx, linalg::All());
60-
}
61-
62-
[[nodiscard]] bst_target_t NumTargets() const { return this->weights.Shape(1); }
63-
[[nodiscard]] bst_node_t Size() const { return this->n; }
64-
};
6525

6626
/**
6727
* @brief Tree structure for multi-target model.
6828
*/
6929
class MultiTargetTree : public Model {
7030
public:
71-
static bst_node_t constexpr InvalidNodeId() { return MultiTargetTreeView::InvalidNodeId(); }
31+
static bst_node_t constexpr InvalidNodeId() { return -1; }
32+
friend struct tree::MultiTargetTreeView;
7233

7334
private:
7435
TreeParam const* param_;
@@ -80,8 +41,6 @@ class MultiTargetTree : public Model {
8041
HostDeviceVector<float> split_conds_;
8142
HostDeviceVector<float> weights_;
8243

83-
mutable std::mutex tree_view_lock_;
84-
8544
[[nodiscard]] linalg::VectorView<float const> NodeWeight(bst_node_t nidx) const {
8645
auto beg = nidx * this->NumTargets();
8746
auto v = this->weights_.ConstHostSpan().subspan(beg, this->NumTargets());
@@ -140,27 +99,27 @@ class MultiTargetTree : public Model {
14099

141100
[[nodiscard]] bst_target_t NumTargets() const;
142101

102+
[[nodiscard]] auto NumLeaves() const { return this->weights_.Size() / this->NumTargets(); }
103+
104+
[[nodiscard]] bool IsLeftChild(bst_node_t nidx) const {
105+
auto p = this->Parent(nidx);
106+
return nidx == this->LeftChild(p);
107+
}
143108
[[nodiscard]] std::size_t Size() const;
109+
[[nodiscard]] MultiTargetTree* Copy(TreeParam const* param) const;
144110

145-
[[nodiscard]] bst_node_t Depth(bst_node_t nidx) const {
146-
bst_node_t depth{0};
147-
while (Parent(nidx) != InvalidNodeId()) {
148-
++depth;
149-
nidx = Parent(nidx);
111+
common::Span<float const> Weights(DeviceOrd device) const {
112+
if (device.IsCPU()) {
113+
return this->weights_.ConstHostSpan();
150114
}
151-
return depth;
115+
this->weights_.SetDevice(device);
116+
return this->weights_.ConstDeviceSpan();
152117
}
153118

154119
[[nodiscard]] linalg::VectorView<float const> LeafValue(bst_node_t nidx) const {
155120
CHECK(IsLeaf(nidx));
156121
return this->NodeWeight(nidx);
157122
}
158-
/**
159-
* @brief Get a view to the tree.
160-
*
161-
* This method is NOT thread-safe.
162-
*/
163-
[[nodiscard]] MultiTargetTreeView View(Context const* ctx) const;
164123

165124
void LoadModel(Json const& in) override;
166125
void SaveModel(Json* out) const override;

include/xgboost/tree_model.h

Lines changed: 36 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
/**
22
* Copyright 2014-2025, XGBoost Contributors
33
*
4-
* \brief model structure for tree
4+
* @brief model structure for tree
55
* \author Tianqi Chen
66
*/
77
#ifndef XGBOOST_TREE_MODEL_H_
@@ -10,20 +10,27 @@
1010
#include <xgboost/base.h>
1111
#include <xgboost/data.h>
1212
#include <xgboost/feature_map.h>
13-
#include <xgboost/linalg.h> // for VectorView
13+
#include <xgboost/host_device_vector.h> // for HostDeviceVector
14+
#include <xgboost/linalg.h> // for VectorView
1415
#include <xgboost/logging.h>
1516
#include <xgboost/model.h>
1617
#include <xgboost/multi_target_tree_model.h> // for MultiTargetTree
1718

1819
#include <algorithm>
1920
#include <cstring>
20-
#include <limits>
21-
#include <memory> // for make_unique
22-
#include <stack>
21+
#include <limits> // for numeric_limits
22+
#include <memory> // for unique_ptr
2323
#include <string>
24+
#include <type_traits> // for is_signed_v
2425
#include <vector>
2526

2627
namespace xgboost {
28+
29+
namespace tree {
30+
struct ScalarTreeView;
31+
struct MultiTargetTreeView;
32+
}
33+
2734
class Json;
2835

2936
/** @brief meta parameters of the tree */
@@ -46,67 +53,39 @@ struct TreeParam {
4653
void ToJson(Json* p_out) const;
4754
};
4855

49-
/*! \brief node statistics used in regression tree */
56+
/** @brief node statistics used in regression tree */
5057
struct RTreeNodeStat {
51-
/*! \brief loss change caused by current split */
52-
bst_float loss_chg;
53-
/*! \brief sum of hessian values, used to measure coverage of data */
54-
bst_float sum_hess;
55-
/*! \brief weight of current node */
56-
bst_float base_weight;
57-
/*! \brief number of child that is leaf node known up to now */
58-
int leaf_child_cnt {0};
58+
/** @brief loss change caused by current split */
59+
float loss_chg;
60+
/** @brief sum of hessian values, used to measure coverage of data */
61+
float sum_hess;
62+
/** @brief weight of current node */
63+
float base_weight;
64+
/** @brief number of child that is leaf node known up to now */
65+
int leaf_child_cnt{0};
5966

6067
RTreeNodeStat() = default;
61-
RTreeNodeStat(float loss_chg, float sum_hess, float weight) :
62-
loss_chg{loss_chg}, sum_hess{sum_hess}, base_weight{weight} {}
68+
RTreeNodeStat(float loss_chg, float sum_hess, float weight)
69+
: loss_chg{loss_chg}, sum_hess{sum_hess}, base_weight{weight} {}
6370
bool operator==(const RTreeNodeStat& b) const {
64-
return loss_chg == b.loss_chg && sum_hess == b.sum_hess &&
65-
base_weight == b.base_weight && leaf_child_cnt == b.leaf_child_cnt;
71+
return loss_chg == b.loss_chg && sum_hess == b.sum_hess && base_weight == b.base_weight &&
72+
leaf_child_cnt == b.leaf_child_cnt;
6673
}
6774
};
6875

6976
/**
70-
* \brief Helper for defining copyable data structure that contains unique pointers.
71-
*/
72-
template <typename T>
73-
class CopyUniquePtr {
74-
std::unique_ptr<T> ptr_{nullptr};
75-
76-
public:
77-
CopyUniquePtr() = default;
78-
CopyUniquePtr(CopyUniquePtr const& that) {
79-
ptr_.reset(nullptr);
80-
if (that.ptr_) {
81-
ptr_ = std::make_unique<T>(*that);
82-
}
83-
}
84-
T* get() const noexcept { return ptr_.get(); } // NOLINT
85-
86-
T& operator*() { return *ptr_; }
87-
T* operator->() noexcept { return this->get(); }
88-
89-
T const& operator*() const { return *ptr_; }
90-
T const* operator->() const noexcept { return this->get(); }
91-
92-
explicit operator bool() const { return static_cast<bool>(ptr_); }
93-
bool operator!() const { return !ptr_; }
94-
void reset(T* ptr) { ptr_.reset(ptr); } // NOLINT
95-
};
96-
97-
/**
98-
* \brief define regression tree to be the most common tree model.
77+
* @brief define regression tree to be the most common tree model.
9978
*
10079
* This is the data structure used in xgboost's major tree models.
10180
*/
10281
class RegTree : public Model {
10382
public:
104-
using SplitCondT = bst_float;
83+
using SplitCondT = float;
10584
static constexpr bst_node_t kInvalidNodeId{MultiTargetTree::InvalidNodeId()};
10685
static constexpr uint32_t kDeletedNodeMarker = std::numeric_limits<uint32_t>::max();
10786
static constexpr bst_node_t kRoot{0};
10887

109-
/*! \brief tree node */
88+
/** @brief tree node */
11089
class Node {
11190
public:
11291
XGBOOST_DEVICE Node() {
@@ -305,31 +284,6 @@ class RegTree : public Model {
305284
return nodes_ == b.nodes_ && stats_ == b.stats_ &&
306285
deleted_nodes_ == b.deleted_nodes_ && param_ == b.param_;
307286
}
308-
/* \brief Iterate through all nodes in this tree.
309-
*
310-
* \param Function that accepts a node index, and returns false when iteration should
311-
* stop, otherwise returns true.
312-
*/
313-
template <typename Func> void WalkTree(Func func) const {
314-
std::stack<bst_node_t> nodes;
315-
nodes.push(kRoot);
316-
auto &self = *this;
317-
while (!nodes.empty()) {
318-
auto nidx = nodes.top();
319-
nodes.pop();
320-
if (!func(nidx)) {
321-
return;
322-
}
323-
auto left = self.LeftChild(nidx);
324-
auto right = self.RightChild(nidx);
325-
if (left != RegTree::kInvalidNodeId) {
326-
nodes.push(left);
327-
}
328-
if (right != RegTree::kInvalidNodeId) {
329-
nodes.push(right);
330-
}
331-
}
332-
}
333287
/*!
334288
* \brief Compares whether 2 trees are equal from a user's perspective. The equality
335289
* compares only non-deleted nodes.
@@ -432,21 +386,10 @@ class RegTree : public Model {
432386
[[nodiscard]] bst_node_t GetNumLeaves() const;
433387
[[nodiscard]] bst_node_t GetNumSplitNodes() const;
434388

435-
/*!
436-
* \brief get current depth
437-
* \param nid node id
389+
/**
390+
* @brief Get the depth of a node.
438391
*/
439-
[[nodiscard]] std::int32_t GetDepth(bst_node_t nid) const {
440-
if (IsMultiTarget()) {
441-
return this->p_mt_tree_->Depth(nid);
442-
}
443-
int depth = 0;
444-
while (!nodes_[nid].IsRoot()) {
445-
++depth;
446-
nid = nodes_[nid].Parent();
447-
}
448-
return depth;
449-
}
392+
[[nodiscard]] bst_node_t GetDepth(bst_node_t nidx) const;
450393
/**
451394
* \brief Set the leaf weight for a multi-target tree.
452395
*/
@@ -649,6 +592,10 @@ class RegTree : public Model {
649592
return this->nodes_.size();
650593
}
651594

595+
[[nodiscard]] RegTree* Copy() const;
596+
tree::ScalarTreeView HostScView() const;
597+
tree::MultiTargetTreeView HostMtView() const;
598+
652599
private:
653600
template <bool typed>
654601
void LoadCategoricalSplit(Json const& in);
@@ -668,7 +615,7 @@ class RegTree : public Model {
668615
// Ptr to split categories of each node.
669616
std::vector<CategoricalSplitMatrix::Segment> split_categories_segments_;
670617
// ptr to multi-target tree with vector leaf.
671-
CopyUniquePtr<MultiTargetTree> p_mt_tree_;
618+
std::unique_ptr<MultiTargetTree> p_mt_tree_;
672619
// allocate a new node,
673620
// !!!!!! NOTE: may cause BUG here, nodes.resize
674621
bst_node_t AllocNode() {

src/gbm/gbtree.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -459,7 +459,7 @@ void GBTree::Slice(bst_layer_t begin, bst_layer_t end, bst_layer_t step, Gradien
459459

460460
*out_of_bound =
461461
detail::SliceTrees(begin, end, step, this->model_, [&](auto in_tree_idx, auto out_l) {
462-
auto new_tree = std::make_unique<RegTree>(*this->model_.trees.at(in_tree_idx));
462+
std::unique_ptr<RegTree> new_tree{this->model_.trees.at(in_tree_idx)->Copy()};
463463
out_trees.emplace_back(std::move(new_tree));
464464

465465
bst_group_t group = this->model_.tree_info[in_tree_idx];

src/gbm/gbtree.h

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@
1818
#include <vector>
1919

2020
#include "../common/timer.h"
21-
#include "../tree/param.h" // TrainParam
21+
#include "../tree/param.h" // TrainParam
22+
#include "../tree/tree_view.h" // for WalkTree
2223
#include "gbtree_model.h"
2324
#include "xgboost/base.h"
2425
#include "xgboost/data.h"
@@ -231,7 +232,7 @@ class GBTree : public GradientBooster {
231232
for (auto idx : trees) {
232233
CHECK_LE(idx, total_n_trees) << "Invalid tree index.";
233234
auto const& tree = *model_.trees[idx];
234-
tree.WalkTree([&](bst_node_t nidx) {
235+
tree::WalkTree(tree, [&](auto const& tree, bst_node_t nidx) {
235236
if (!tree.IsLeaf(nidx)) {
236237
split_counts[tree.SplitIndex(nidx)]++;
237238
fn(tree, nidx, tree.SplitIndex(nidx));
@@ -246,18 +247,20 @@ class GBTree : public GradientBooster {
246247
gain_map[split] = split_counts[split];
247248
});
248249
} else if (importance_type == "gain" || importance_type == "total_gain") {
249-
if (!model_.trees.empty() && model_.trees.front()->IsMultiTarget()) {
250-
LOG(FATAL) << "gain/total_gain " << MTNotImplemented();
251-
}
252250
add_score([&](auto const& tree, bst_node_t nidx, bst_feature_t split) {
253-
gain_map[split] += tree.Stat(nidx).loss_chg;
251+
if constexpr (tree::IsScalarTree<decltype(tree)>()) {
252+
gain_map[split] += tree.Stat(nidx).loss_chg;
253+
} else {
254+
LOG(FATAL) << "gain/total_gain " << MTNotImplemented();
255+
}
254256
});
255257
} else if (importance_type == "cover" || importance_type == "total_cover") {
256-
if (!model_.trees.empty() && model_.trees.front()->IsMultiTarget()) {
257-
LOG(FATAL) << "cover/total_cover " << MTNotImplemented();
258-
}
259258
add_score([&](auto const& tree, bst_node_t nidx, bst_feature_t split) {
260-
gain_map[split] += tree.Stat(nidx).sum_hess;
259+
if constexpr (tree::IsScalarTree<decltype(tree)>()) {
260+
gain_map[split] += tree.Stat(nidx).sum_hess;
261+
} else {
262+
LOG(FATAL) << "cover/total_cover " << MTNotImplemented();
263+
}
261264
});
262265
} else {
263266
LOG(FATAL)

0 commit comments

Comments
 (0)