@@ -74,20 +74,22 @@ struct CategoriesMixIn {
7474 RegTree::CategoricalSplitMatrix cats;
7575
7676 [[nodiscard]] XGBOOST_DEVICE bool HasCategoricalSplit () const { return !cats.categories .empty (); }
77- [[nodiscard]] XGBOOST_DEVICE RegTree::CategoricalSplitMatrix GetCategoriesMatrix () const {
77+ [[nodiscard]] XGBOOST_DEVICE RegTree::CategoricalSplitMatrix const & GetCategoriesMatrix () const {
7878 return cats;
7979 }
8080 /* *
8181 * @brief Get the bit storage of categories used by a node.
8282 */
83- [[nodiscard]] common::Span<uint32_t const > NodeCats (bst_node_t nidx) const {
83+ [[nodiscard]] XGBOOST_DEVICE common::Span<uint32_t const > NodeCats (bst_node_t nidx) const {
8484 auto node_ptr = this ->GetCategoriesMatrix ().node_ptr ;
8585 auto categories = this ->GetCategoriesMatrix ().categories ;
8686 auto segment = node_ptr[nidx];
8787 auto node_cats = categories.subspan (segment.beg , segment.size );
8888 return node_cats;
8989 }
90- [[nodiscard]] FeatureType SplitType (bst_node_t nidx) const { return cats.split_type [nidx]; }
90+ [[nodiscard]] XGBOOST_DEVICE FeatureType SplitType (bst_node_t nidx) const {
91+ return cats.split_type [nidx];
92+ }
9193};
9294
9395/* *
@@ -142,20 +144,20 @@ struct ScalarTreeView : public WalkTreeMixIn<ScalarTreeView>, public CategoriesM
142144 }
143145
144146 [[nodiscard]] RTreeNodeStat const & Stat (bst_node_t nidx) const { return stats[nidx]; }
145- [[nodiscard]] auto SumHess (bst_node_t nidx) const { return stats[nidx].sum_hess ; }
146- [[nodiscard]] auto LossChg (bst_node_t nidx) const { return stats[nidx].loss_chg ; }
147+ [[nodiscard]] XGBOOST_DEVICE auto SumHess (bst_node_t nidx) const { return stats[nidx].sum_hess ; }
148+ [[nodiscard]] XGBOOST_DEVICE auto LossChg (bst_node_t nidx) const { return stats[nidx].loss_chg ; }
147149
148150 XGBOOST_DEVICE explicit ScalarTreeView (RegTree::Node const * nodes, RTreeNodeStat const * stats,
149151 RegTree::CategoricalSplitMatrix cats, bst_node_t n_nodes)
150152 : CategoriesMixIn{std::move (cats)}, nodes{nodes}, stats{stats}, n{n_nodes} {}
151153
152- /* * @brief Create a device view, not implemented yet. */
154+ /* * @brief Create a device view */
153155 explicit ScalarTreeView (Context const * ctx, RegTree const * tree);
154156 /* * @brief Create a host view */
155157 explicit ScalarTreeView (RegTree const * tree)
156158 : CategoriesMixIn{tree->GetCategoriesMatrix (DeviceOrd::CPU ())},
157- nodes{tree->GetNodes ().data ()},
158- stats{tree->GetStats ().data ()},
159+ nodes{tree->GetNodes (DeviceOrd::CPU () ).data ()},
160+ stats{tree->GetStats (DeviceOrd::CPU () ).data ()},
159161 n{tree->NumNodes ()} {
160162 CHECK (!tree->IsMultiTarget ());
161163 }
0 commit comments