@@ -174,21 +174,24 @@ void SortPositionBatch(Context const* ctx, common::Span<const PerNodeData<OpData
174174 auto ret =
175175 cub::DispatchScan<decltype (input_iterator), decltype (discard_write_iterator), IndexFlagOp,
176176 cub::NullType, std::uint64_t >::Dispatch (nullptr , n_bytes, input_iterator,
177- discard_write_iterator,
178- IndexFlagOp{}, cub::NullType{},
179- static_cast <std::uint64_t >(total_rows),
180- ctx->CUDACtx ()->Stream ());
177+ discard_write_iterator,
178+ IndexFlagOp{}, cub::NullType{},
179+ static_cast <std::uint64_t >(
180+ total_rows),
181+ ctx->CUDACtx ()->Stream ());
181182 dh::safe_cuda (ret);
182183 tmp->resize (n_bytes);
183184 }
184185 n_bytes = tmp->size ();
185186 auto ret =
186187 cub::DispatchScan<decltype (input_iterator), decltype (discard_write_iterator), IndexFlagOp,
187- cub::NullType, std::uint64_t >::Dispatch (tmp->data (), n_bytes, input_iterator,
188- discard_write_iterator,
189- IndexFlagOp{}, cub::NullType{},
190- static_cast <std::uint64_t >(total_rows),
191- ctx->CUDACtx ()->Stream ());
188+ cub::NullType, std::uint64_t >::Dispatch (tmp->data (), n_bytes,
189+ input_iterator,
190+ discard_write_iterator,
191+ IndexFlagOp{}, cub::NullType{},
192+ static_cast <std::uint64_t >(
193+ total_rows),
194+ ctx->CUDACtx ()->Stream ());
192195 dh::safe_cuda (ret);
193196
194197 constexpr int kBlockSize = 256 ;
@@ -272,8 +275,6 @@ class RowPartitioner {
272275 * rows idx | 3, 5, 1 | 13, 31 |
273276 */
274277 dh::DeviceUVector<RowIndexT> ridx_;
275- // Staging area for sorting ridx
276- dh::DeviceUVector<RowIndexT> ridx_tmp_;
277278 dh::DeviceUVector<int8_t > tmp_;
278279 dh::PinnedMemory pinned_;
279280 dh::PinnedMemory pinned2_;
@@ -343,7 +344,8 @@ class RowPartitioner {
343344 void UpdatePositionBatch (Context const * ctx, std::vector<bst_node_t > const & nidx,
344345 std::vector<bst_node_t > const & left_nidx,
345346 std::vector<bst_node_t > const & right_nidx,
346- std::vector<OpDataT> const & op_data, UpdatePositionOpT op) {
347+ std::vector<OpDataT> const & op_data, common::Span<RowIndexT> ridx_tmp,
348+ UpdatePositionOpT op) {
347349 if (nidx.empty ()) {
348350 return ;
349351 }
@@ -366,20 +368,21 @@ class RowPartitioner {
366368 auto h_counts = pinned_.GetSpan <RowIndexT>(nidx.size ());
367369 // Must initialize with 0 as 0 count is not written in the kernel.
368370 dh::TemporaryArray<RowIndexT> d_counts (nidx.size (), 0 );
371+ CHECK_EQ (ridx_tmp.size (), this ->Size ());
369372
370373 // Process a sub-batch
371- auto sub_batch_impl = [ctx, op, this ](common::Span<bst_node_t const > nidx,
372- common::Span<PerNodeData<OpDataT>> d_batch_info,
373- common::Span<RowIndexT> d_counts) {
374+ auto sub_batch_impl = [& ](common::Span<bst_node_t const > nidx,
375+ common::Span<PerNodeData<OpDataT>> d_batch_info,
376+ common::Span<RowIndexT> d_counts) {
374377 std::size_t total_rows = 0 ;
375378 for (bst_node_t i : nidx) {
376379 total_rows += this ->ridx_segments_ [i].segment .Size ();
377380 }
378381
379382 // Partition the rows according to the operator
380383 SortPositionBatch<UpdatePositionOpT, OpDataT>(ctx, d_batch_info, dh::ToSpan (this ->ridx_ ),
381- dh::ToSpan ( this -> ridx_tmp_ ) , d_counts,
382- total_rows, op, &this ->tmp_ );
384+ ridx_tmp , d_counts, total_rows, op ,
385+ &this ->tmp_ );
383386 };
384387
385388 // Divide inputs into sub-batches.
@@ -441,4 +444,59 @@ class RowPartitioner {
441444 base_ridx, d_ridx, d_out_position, op);
442445 }
443446};
447+
448+ // Partitioner for all batches, used for external memory training.
449+ class RowPartitionerBatches {
450+ private:
451+ // Temporary buffer for sorting the samples.
452+ dh::DeviceUVector<cuda_impl::RowIndexT> ridx_tmp_;
453+ // Partitioners for each batch.
454+ std::vector<std::unique_ptr<RowPartitioner>> partitioners_;
455+
456+ public:
457+ void Reset (Context const * ctx, std::vector<bst_idx_t > const & batch_ptr) {
458+ CHECK_GE (batch_ptr.size (), 2 );
459+ std::size_t n_batches = batch_ptr.size () - 1 ;
460+ if (partitioners_.size () != n_batches) {
461+ partitioners_.clear ();
462+ }
463+
464+ bst_idx_t n_max_samples = 0 ;
465+ for (std::size_t k = 0 ; k < n_batches; ++k) {
466+ if (partitioners_.size () != n_batches) {
467+ // First run.
468+ partitioners_.emplace_back (std::make_unique<RowPartitioner>());
469+ }
470+ auto base_ridx = batch_ptr[k];
471+ auto n_samples = batch_ptr.at (k + 1 ) - base_ridx;
472+ partitioners_[k]->Reset (ctx, n_samples, base_ridx);
473+ CHECK_LE (n_samples, std::numeric_limits<cuda_impl::RowIndexT>::max ());
474+ n_max_samples = std::max (n_samples, n_max_samples);
475+ }
476+ this ->ridx_tmp_ .resize (n_max_samples);
477+ }
478+
479+ // Accessors
480+ [[nodiscard]] decltype (auto ) operator [](std::size_t i) { return partitioners_[i]; }
481+ decltype (auto ) At(std::size_t i) { return partitioners_.at (i); }
482+ [[nodiscard]] std::size_t Size () const { return this ->partitioners_ .size (); }
483+ decltype (auto ) cbegin() const { return this ->partitioners_ .cbegin (); } // NOLINT
484+ decltype (auto ) cend() const { return this ->partitioners_ .cend (); } // NOLINT
485+ decltype (auto ) begin() const { return this ->partitioners_ .cbegin (); } // NOLINT
486+ decltype (auto ) end() const { return this ->partitioners_ .cend (); } // NOLINT
487+
488+ [[nodiscard]] decltype (auto ) Front() { return this ->partitioners_ .front (); }
489+ [[nodiscard]] bool Empty () const { return this ->partitioners_ .empty (); }
490+
491+ template <typename UpdatePositionOpT, typename OpDataT>
492+ void UpdatePositionBatch (Context const * ctx, std::int32_t batch_idx,
493+ std::vector<bst_node_t > const & nidx,
494+ std::vector<bst_node_t > const & left_nidx,
495+ std::vector<bst_node_t > const & right_nidx,
496+ std::vector<OpDataT> const & op_data, UpdatePositionOpT op) {
497+ auto & part = this ->At (batch_idx);
498+ auto ridx_tmp = dh::ToSpan (this ->ridx_tmp_ ).subspan (0 , part->Size ());
499+ part->UpdatePositionBatch (ctx, nidx, left_nidx, right_nidx, op_data, ridx_tmp, op);
500+ }
501+ };
444502}; // namespace xgboost::tree
0 commit comments