2727#include " utils.h" // for CheckProxyDMatrix
2828#include " xgboost/data.h"
2929#include " xgboost/host_device_vector.h"
30+ #include " xgboost/multi_target_tree_model.h" // for MultiTargetTree, MultiTargetTreeView
3031#include " xgboost/predictor.h"
3132#include " xgboost/tree_model.h"
3233#include " xgboost/tree_updater.h"
@@ -243,6 +244,55 @@ struct DeviceAdapterLoader {
243244 }
244245};
245246
247+ namespace multi {
/**
 * @brief Choose which child of `nidx` to visit next in a multi-target tree.
 *
 * When `has_missing` is enabled and the feature value is flagged as missing, the tree's
 * default child is taken. Otherwise the value is compared with the node's split
 * condition: strictly smaller goes left, everything else goes right.
 *
 * @tparam has_missing     Whether the missing-value branch is compiled in.
 * @tparam has_categorical Accepted for signature parity with callers; not used in the body.
 *
 * @param tree       View of the tree being traversed.
 * @param nidx       Index of the current (non-leaf) node.
 * @param fvalue     Feature value for the split feature of `nidx`.
 * @param is_missing Whether `fvalue` is considered missing.
 *
 * @return Index of the child node to descend into.
 */
template <bool has_missing, bool has_categorical>
XGBOOST_DEVICE bst_node_t GetNextNode(MultiTargetTreeView const& tree, bst_node_t const nidx,
                                      float fvalue, bool is_missing) {
  // Missing values follow the default direction stored in the tree.
  if (has_missing && is_missing) {
    return tree.DefaultChild(nidx);
  }
  // Numerical split: strictly-less goes left.
  if (fvalue < tree.SplitCond(nidx)) {
    return tree.LeftChild(nidx);
  }
  return tree.RightChild(nidx);
}
257+
/**
 * @brief Walk a multi-target tree from the root (node 0) to a leaf for one row.
 *
 * At every internal node the split feature's value is fetched through `loader`, tested
 * for NaN, and passed to GetNextNode to select the child to visit.
 *
 * @tparam has_missing     Forwarded to GetNextNode.
 * @tparam has_categorical Forwarded to GetNextNode.
 * @tparam Loader          Type providing `GetElement(ridx, fidx)`.
 *
 * @param ridx   Row index handed to the loader.
 * @param tree   View of the tree being traversed.
 * @param loader Feature value accessor.
 *
 * @return Index of the leaf node reached.
 */
template <bool has_missing, bool has_categorical, typename Loader>
__device__ bst_node_t GetLeafIndex(bst_idx_t ridx, MultiTargetTreeView const& tree,
                                   Loader* loader) {
  bst_node_t current = 0;
  for (;;) {
    if (tree.IsLeaf(current)) {
      return current;
    }
    float const value = loader->GetElement(ridx, tree.SplitIndex(current));
    bool const value_is_nan = common::CheckNAN(value);
    bst_node_t const child =
        GetNextNode<has_missing, has_categorical>(tree, current, value, value_is_nan);
    // A child must have a larger index than its parent; this also guards against cycles.
    assert(current < child);
    current = child;
  }
}
271+
/**
 * @brief Look up the leaf value of a multi-target tree for one row.
 *
 * The leaf is located with GetLeafIndex, with the categorical flag fixed to `false`.
 *
 * @param ridx   Row index handed to the loader.
 * @param tree   View of the tree being evaluated.
 * @param loader Feature value accessor.
 *
 * @return Whatever `tree.LeafValue` yields for the reached leaf.
 */
template <bool has_missing, typename Loader>
__device__ auto GetLeafWeight(bst_idx_t ridx, MultiTargetTreeView const& tree, Loader* loader) {
  return tree.LeafValue(GetLeafIndex<has_missing, false>(ridx, tree, loader));
}
277+
/**
 * @brief Accumulate multi-target tree predictions into the output matrix.
 *
 * For every row visited by `dh::GridStrideRange` over `[0, data.NumRows())`, each tree in
 * `trees` is traversed and every element of the resulting leaf value is added to the
 * matching column of that row in `d_out_predt`.
 *
 * The loader is built exactly once per thread, before the row loop, in the same way the
 * scalar PredictKernel builds its loader before any row is touched. `data` and `acc` are
 * therefore moved from only once, and the shape of `data` is read before the move.
 *
 * NOTE(review): if the Loader constructor executes a block-wide barrier when
 * `use_shared` is set (not visible from here, confirm against the loader), it has to be
 * reached by every thread of the block. Constructing the loader outside the loop
 * guarantees that even for threads that have no row to process.
 *
 * @tparam Loader      Feature accessor built from `data`.
 * @tparam Data        Input batch type providing `NumRows()` and `NumCols()`.
 * @tparam has_missing Forwarded to GetLeafWeight.
 * @tparam EncAccessor Accessor type forwarded into the loader.
 *
 * @param data        Input batch.
 * @param trees       Views of the trees to evaluate.
 * @param use_shared  Forwarded to the loader.
 * @param missing     Forwarded to the loader.
 * @param d_out_predt Output matrix indexed as (row, target); values are accumulated.
 * @param acc         Accessor forwarded to the loader.
 */
template <typename Loader, typename Data, bool has_missing, typename EncAccessor>
__global__ void PredictKernel(Data data, common::Span<MultiTargetTreeView> trees, bool use_shared,
                              float missing, linalg::MatrixView<float> d_out_predt,
                              EncAccessor acc) {
  // Read the shape before `data` is handed over to the loader.
  auto const n_rows = data.NumRows();
  auto const n_cols = static_cast<bst_feature_t>(data.NumCols());
  Loader loader{std::move(data), use_shared, n_cols, n_rows, missing, std::move(acc)};

  for (auto idx : dh::GridStrideRange(static_cast<std::size_t>(0), n_rows)) {
    for (auto const& tree : trees) {
      auto leaf = GetLeafWeight<has_missing>(idx, tree, &loader);
      // One output column per element of the leaf value.
      for (std::size_t i = 0, n = leaf.Shape(0); i < n; ++i) {
        d_out_predt(idx, i) += leaf(i);
      }
    }
  }
}
293+ } // namespace multi
294+
295+ namespace scalar {
246296template <bool has_missing, bool has_categorical, typename Loader>
247297__device__ bst_node_t GetLeafIndex (bst_idx_t ridx, TreeView const & tree, Loader* loader) {
248298 bst_node_t nidx = 0 ;
@@ -257,8 +307,7 @@ __device__ bst_node_t GetLeafIndex(bst_idx_t ridx, TreeView const& tree, Loader*
257307}
258308
259309template <bool has_missing, typename Loader>
260- __device__ float GetLeafWeight (bst_idx_t ridx, TreeView const &tree,
261- Loader *loader) {
310+ __device__ float GetLeafWeight (bst_idx_t ridx, TreeView const & tree, Loader* loader) {
262311 bst_node_t nidx = -1 ;
263312 if (tree.HasCategoricalSplit ()) {
264313 nidx = GetLeafIndex<has_missing, true >(ridx, tree, loader);
@@ -267,6 +316,7 @@ __device__ float GetLeafWeight(bst_idx_t ridx, TreeView const &tree,
267316 }
268317 return tree.d_tree [nidx].LeafValue ();
269318}
319+ } // namespace scalar
270320
271321template <typename Loader, typename Data, bool has_missing, typename EncAccessor>
272322__global__ void
@@ -295,9 +345,9 @@ PredictLeafKernel(Data data, common::Span<const RegTree::Node> d_nodes,
295345
296346 bst_node_t leaf = -1 ;
297347 if (d_tree.HasCategoricalSplit ()) {
298- leaf = GetLeafIndex<has_missing, true >(ridx, d_tree, &loader);
348+ leaf = scalar:: GetLeafIndex<has_missing, true >(ridx, d_tree, &loader);
299349 } else {
300- leaf = GetLeafIndex<has_missing, false >(ridx, d_tree, &loader);
350+ leaf = scalar:: GetLeafIndex<has_missing, false >(ridx, d_tree, &loader);
301351 }
302352 d_out_predictions[ridx * (tree_end - tree_begin) + tree_idx] = leaf;
303353 }
@@ -313,7 +363,7 @@ PredictKernel(Data data, common::Span<const RegTree::Node> d_nodes,
313363 common::Span<uint32_t const > d_cat_tree_segments,
314364 common::Span<RegTree::CategoricalSplitMatrix::Segment const > d_cat_node_segments,
315365 common::Span<uint32_t const > d_categories, bst_tree_t tree_begin,
316- bst_tree_t tree_end, bst_feature_t num_features, size_t num_rows,
366+ bst_tree_t tree_end, bst_feature_t num_features, bst_idx_t num_rows,
317367 bool use_shared, int num_group, float missing, EncAccessor acc) {
318368 bst_uint global_idx = blockDim .x * blockIdx .x + threadIdx .x ;
319369 Loader loader{std::move (data), use_shared, num_features, num_rows, missing, std::move (acc)};
@@ -326,20 +376,19 @@ PredictKernel(Data data, common::Span<const RegTree::Node> d_nodes,
326376 tree_begin, tree_idx, d_nodes,
327377 d_tree_segments, d_tree_split_types, d_cat_tree_segments,
328378 d_cat_node_segments, d_categories};
329- float leaf = GetLeafWeight<has_missing>(global_idx, d_tree, &loader);
379+ float leaf = scalar:: GetLeafWeight<has_missing>(global_idx, d_tree, &loader);
330380 sum += leaf;
331381 }
332382 d_out_predictions[global_idx] += sum;
333383 } else {
334384 for (bst_tree_t tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
335385 int tree_group = d_tree_group[tree_idx];
336- TreeView d_tree{
337- tree_begin, tree_idx, d_nodes,
338- d_tree_segments, d_tree_split_types, d_cat_tree_segments,
339- d_cat_node_segments, d_categories};
386+ TreeView d_tree{tree_begin, tree_idx, d_nodes,
387+ d_tree_segments, d_tree_split_types, d_cat_tree_segments,
388+ d_cat_node_segments, d_categories};
340389 bst_uint out_prediction_idx = global_idx * num_group + tree_group;
341390 d_out_predictions[out_prediction_idx] +=
342- GetLeafWeight<has_missing>(global_idx, d_tree, &loader);
391+ scalar:: GetLeafWeight<has_missing>(global_idx, d_tree, &loader);
343392 }
344393 }
345394}
@@ -400,12 +449,12 @@ class DeviceModel {
400449 for (auto tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
401450 auto & src_nodes = model.trees .at (tree_idx)->GetNodes ();
402451 auto & src_stats = model.trees .at (tree_idx)->GetStats ();
403- dh::safe_cuda (cudaMemcpyAsync (
404- d_nodes + h_tree_segments[tree_idx - tree_begin], src_nodes.data (),
405- sizeof (RegTree::Node) * src_nodes. size (), cudaMemcpyDefault));
406- dh::safe_cuda (cudaMemcpyAsync (
407- d_stats + h_tree_segments[tree_idx - tree_begin], src_stats.data (),
408- sizeof (RTreeNodeStat) * src_stats. size (), cudaMemcpyDefault));
452+ dh::safe_cuda (cudaMemcpyAsync (d_nodes + h_tree_segments[tree_idx - tree_begin],
453+ src_nodes.data (), sizeof (RegTree::Node) * src_nodes. size (),
454+ cudaMemcpyDefault));
455+ dh::safe_cuda (cudaMemcpyAsync (d_stats + h_tree_segments[tree_idx - tree_begin],
456+ src_stats.data (), sizeof (RTreeNodeStat) * src_stats. size (),
457+ cudaMemcpyDefault));
409458 }
410459
411460 tree_group = HostDeviceVector<int >(model.tree_info .size (), 0 , device);
@@ -424,14 +473,13 @@ class DeviceModel {
424473
425474 categories = HostDeviceVector<uint32_t >({}, device);
426475 categories_tree_segments = HostDeviceVector<uint32_t >(1 , 0 , device);
427- std::vector<uint32_t > & h_categories = categories.HostVector ();
428- std::vector<uint32_t > & h_split_cat_segments = categories_tree_segments.HostVector ();
476+ std::vector<uint32_t >& h_categories = categories.HostVector ();
477+ std::vector<uint32_t >& h_split_cat_segments = categories_tree_segments.HostVector ();
429478 for (auto tree_idx = tree_begin; tree_idx < tree_end; ++tree_idx) {
430479 auto const & src_cats = model.trees .at (tree_idx)->GetSplitCategories ();
431480 size_t orig_size = h_categories.size ();
432481 h_categories.resize (orig_size + src_cats.size ());
433- std::copy (src_cats.cbegin (), src_cats.cend (),
434- h_categories.begin () + orig_size);
482+ std::copy (src_cats.cbegin (), src_cats.cend (), h_categories.begin () + orig_size);
435483 h_split_cat_segments.push_back (h_categories.size ());
436484 }
437485
@@ -974,7 +1022,7 @@ class LaunchConfig {
9741022 void LaunchPredict (Context const * ctx, Data data, float missing, bst_idx_t n_samples,
9751023 bst_feature_t n_features, DeviceModel const & model, bool is_dense,
9761024 enc::DeviceColumnsView const & new_enc, bst_idx_t batch_offset,
977- HostDeviceVector<bst_float >* predictions) const {
1025+ HostDeviceVector<float >* predictions) const {
9781026 LaunchPredictKernel (ctx, is_dense, new_enc, model, [&](auto is_dense, auto && acc) {
9791027 constexpr bool kHasMissing = !std::is_same_v<decltype (is_dense), std::true_type>;
9801028 using EncAccessor = std::remove_reference_t <decltype (acc)>;
@@ -993,10 +1041,30 @@ class LaunchConfig {
9931041 });
9941042 }
9951043
1044+ template <template <typename > typename Loader, typename Data>
1045+ void LaunchMultiPredict (Context const * ctx, Data data, gbm::GBTreeModel const & model,
1046+ float missing, bst_tree_t tree_begin, bst_tree_t tree_end,
1047+ bst_idx_t batch_offset, HostDeviceVector<float >* predictions) const {
1048+ CHECK_EQ (batch_offset, 0 ); // External memory is not supported yet.
1049+ CHECK_GT (tree_end, tree_begin);
1050+ std::vector<MultiTargetTreeView> h_trees;
1051+ for (bst_tree_t tree_idx = tree_begin; tree_idx < tree_end; ++tree_idx) {
1052+ h_trees.emplace_back (model.trees [tree_idx]->GetMultiTargetTree ()->View (ctx));
1053+ }
1054+ dh::device_vector<MultiTargetTreeView> trees = h_trees;
1055+ CHECK_GE (predictions->Size (), data.NumRows () * h_trees.front ().NumTargets ());
1056+ auto kernel = multi::PredictKernel<Loader<NoOpAccessor>, Data, true , NoOpAccessor>;
1057+ auto predt =
1058+ linalg::MakeTensorView (ctx, predictions, data.NumRows (), h_trees.front ().NumTargets ());
1059+ this ->Grid (data.NumRows ())
1060+ .LaunchImpl (std::move (kernel), std::move (data), dh::ToSpan (trees), this ->UseShared (),
1061+ missing, predt, NoOpAccessor{});
1062+ }
1063+
9961064 template <template <typename > typename Loader, typename Data>
9971065 void LaunchLeaf (Context const * ctx, Data data, bst_idx_t n_samples, bst_feature_t n_features,
9981066 DeviceModel const & model, bool is_dense, enc::DeviceColumnsView const & new_enc,
999- bst_idx_t batch_offset, HostDeviceVector<bst_float >* predictions) const {
1067+ bst_idx_t batch_offset, HostDeviceVector<float >* predictions) const {
10001068 LaunchPredictKernel (ctx, is_dense, new_enc, model, [&](auto is_dense, auto && acc) {
10011069 constexpr bool kHasMissing = !std::is_same_v<decltype (is_dense), std::true_type>;
10021070 using EncAccessor = std::remove_reference_t <decltype (acc)>;
@@ -1037,7 +1105,9 @@ class GPUPredictor : public xgboost::Predictor {
10371105 out_preds->SetDevice (ctx_->Device ());
10381106 auto const & info = p_fmat->Info ();
10391107 DeviceModel d_model;
1040- d_model.Init (model, tree_begin, tree_end, ctx_->Device ());
1108+ if (!model.trees [tree_begin]->IsMultiTarget ()) {
1109+ d_model.Init (model, tree_begin, tree_end, ctx_->Device ());
1110+ }
10411111
10421112 if (info.IsColumnSplit ()) {
10431113 column_split_helper_.PredictBatch (p_fmat, out_preds, model, d_model);
@@ -1056,9 +1126,15 @@ class GPUPredictor : public xgboost::Predictor {
10561126 auto n_features = model.learner_model_param ->num_feature ;
10571127 LaunchConfig cfg{ctx_, n_features};
10581128 SparsePageView data (page.data .DeviceSpan (), page.offset .DeviceSpan (), n_features);
1059- cfg.LaunchPredict <SparsePageLoader>(
1060- this ->ctx_ , std::move (data), std::numeric_limits<float >::quiet_NaN (), page.Size (),
1061- n_features, d_model, p_fmat->IsDense (), new_enc, batch_offset, out_preds);
1129+ if (model.trees [tree_begin]->IsMultiTarget ()) {
1130+ cfg.LaunchMultiPredict <SparsePageLoader>(this ->ctx_ , std::move (data), model,
1131+ std::numeric_limits<float >::quiet_NaN (),
1132+ tree_begin, tree_end, batch_offset, out_preds);
1133+ } else {
1134+ cfg.LaunchPredict <SparsePageLoader>(
1135+ this ->ctx_ , std::move (data), std::numeric_limits<float >::quiet_NaN (), page.Size (),
1136+ n_features, d_model, p_fmat->IsDense (), new_enc, batch_offset, out_preds);
1137+ }
10621138 batch_offset += page.Size () * model.learner_model_param ->OutputLength ();
10631139 }
10641140 } else {
@@ -1158,7 +1234,7 @@ class GPUPredictor : public xgboost::Predictor {
11581234
11591235 void PredictContribution (DMatrix* p_fmat, HostDeviceVector<float >* out_contribs,
11601236 const gbm::GBTreeModel& model, bst_tree_t tree_end,
1161- std::vector<bst_float > const * tree_weights, bool approximate, int ,
1237+ std::vector<float > const * tree_weights, bool approximate, int ,
11621238 unsigned ) const override {
11631239 StringView not_implemented{
11641240 " contribution is not implemented in the GPU predictor, use CPU instead." };
@@ -1177,8 +1253,7 @@ class GPUPredictor : public xgboost::Predictor {
11771253 const int ngroup = model.learner_model_param ->num_output_group ;
11781254 CHECK_NE (ngroup, 0 );
11791255 // allocate space for (number of features + bias) times the number of rows
1180- size_t contributions_columns =
1181- model.learner_model_param ->num_feature + 1 ; // +1 for bias
1256+ size_t contributions_columns = model.learner_model_param ->num_feature + 1 ; // +1 for bias
11821257 auto dim_size = contributions_columns * model.learner_model_param ->num_output_group ;
11831258 out_contribs->Resize (p_fmat->Info ().num_row_ * dim_size);
11841259 out_contribs->Fill (0 .0f );
@@ -1245,8 +1320,8 @@ class GPUPredictor : public xgboost::Predictor {
12451320 gbm::GBTreeModel const & model, bst_tree_t tree_end,
12461321 std::vector<float > const * tree_weights,
12471322 bool approximate) const override {
1248- std::string not_implemented{" contribution is not implemented in GPU "
1249- " predictor, use `cpu_predictor` instead." };
1323+ std::string not_implemented{
1324+ " contribution is not implemented in GPU predictor, use cpu instead." };
12501325 if (approximate) {
12511326 LOG (FATAL) << " Approximated " << not_implemented;
12521327 }
@@ -1333,7 +1408,6 @@ class GPUPredictor : public xgboost::Predictor {
13331408 gbm::GBTreeModel const & model, bst_tree_t tree_end) const override {
13341409 dh::safe_cuda (cudaSetDevice (ctx_->Ordinal ()));
13351410
1336-
13371411 const MetaInfo& info = p_fmat->Info ();
13381412 bst_idx_t n_samples = info.num_row_ ;
13391413 tree_end = GetTreeLimit (model.trees , tree_end);
0 commit comments