@@ -145,6 +145,21 @@ void GBTree::Configure(Args const& cfg) {
145145 }
146146}
147147
148+ void GBTreeModel::InitTreesToUpdate () {
149+ if (trees_to_update.empty ()) {
150+ for (auto & tree : trees) {
151+ trees_to_update.push_back (std::move (tree));
152+ }
153+
154+ trees.clear ();
155+ param.num_trees = 0 ;
156+ tree_info.HostVector ().clear ();
157+
158+ iteration_indptr.clear ();
159+ iteration_indptr.push_back (0 );
160+ }
161+ }
162+
148163void GPUCopyGradient (Context const *, linalg::Matrix<GradientPair> const *, bst_group_t ,
149164 linalg::Matrix<GradientPair>*)
150165#if defined(XGBOOST_USE_CUDA)
@@ -444,7 +459,9 @@ void GBTree::Slice(bst_layer_t begin, bst_layer_t end, bst_layer_t step, Gradien
444459
445460 auto & out_indptr = out_model.iteration_indptr ;
446461 TreesOneGroup& out_trees = out_model.trees ;
447- std::vector<int32_t >& out_trees_info = out_model.tree_info ;
462+ auto & out_tree_info = out_model.tree_info .HostVector ();
463+
464+ auto const & in_tree_info = this ->model_ .tree_info .ConstHostVector ();
448465
449466 bst_layer_t n_layers = (end - begin) / step;
450467 out_indptr.resize (n_layers + 1 , 0 );
@@ -462,8 +479,8 @@ void GBTree::Slice(bst_layer_t begin, bst_layer_t end, bst_layer_t step, Gradien
462479 std::unique_ptr<RegTree> new_tree{this ->model_ .trees .at (in_tree_idx)->Copy ()};
463480 out_trees.emplace_back (std::move (new_tree));
464481
465- bst_group_t group = this -> model_ . tree_info [in_tree_idx];
466- out_trees_info .push_back (group);
482+ bst_group_t group = in_tree_info [in_tree_idx];
483+ out_tree_info .push_back (group);
467484
468485 out_model.iteration_indptr [out_l + 1 ]++;
469486 });
@@ -735,7 +752,7 @@ class Dart : public GBTree {
735752 auto layer_trees = [&]() {
736753 return model_.param .num_parallel_tree * model_.learner_model_param ->OutputLength ();
737754 };
738-
755+ auto const & h_tree_info = this -> model_ . tree_info . ConstHostVector ();
739756 for (bst_tree_t i = tree_begin; i < tree_end; i += 1 ) {
740757 if (training && std::binary_search (idx_drop_.cbegin (), idx_drop_.cend (), i)) {
741758 continue ;
@@ -749,7 +766,7 @@ class Dart : public GBTree {
749766
750767 // Multiple the weight to output prediction.
751768 auto w = this ->weight_drop_ .at (i);
752- auto group = model_. tree_info .at (i);
769+ auto group = h_tree_info .at (i);
753770 CHECK_EQ (p_out_preds->predictions .Size (), predts.predictions .Size ());
754771
755772 size_t n_rows = p_fmat->Info ().num_row_ ;
@@ -815,6 +832,7 @@ class Dart : public GBTree {
815832 CHECK (success) << msg;
816833 };
817834
835+ auto const & h_tree_info = this ->model_ .tree_info .ConstHostVector ();
818836 // Inplace predict is not used for training, so no need to drop tree.
819837 for (bst_tree_t i = tree_begin; i < tree_end; ++i) {
820838 predict_impl (i);
@@ -837,7 +855,7 @@ class Dart : public GBTree {
837855 }
838856 // Multiple the tree weight
839857 auto w = this ->weight_drop_ .at (i);
840- auto group = model_. tree_info .at (i);
858+ auto group = h_tree_info .at (i);
841859 CHECK_EQ (predts.predictions .Size (), p_out_preds->predictions .Size ());
842860
843861 size_t n_rows = p_fmat->Info ().num_row_ ;
0 commit comments