@@ -228,7 +228,8 @@ template <typename Partitioner>
228228common::BlockedSpace2d ConstructHistSpace (Partitioner const &partitioners,
229229 std::vector<bst_node_t > const &nodes_to_build,
230230 const GHistIndexMatrix &gidx,
231- std::size_t l1_size, bool read_by_column) {
231+ std::size_t l1_size, bst_bin_t max_bin,
232+ bool read_by_column) {
232233 // FIXME(jiamingy): Handle different size of space. Right now we use the maximum
233234 // partition size for the buffer, which might not be efficient if partition sizes
234235 // has significant variance.
@@ -262,12 +263,9 @@ common::BlockedSpace2d ConstructHistSpace(Partitioner const &partitioners,
262263 */
263264
264265 /* First step: determine whether one histogram column fits into L1.
265- * The maximum number of bins in a column is 2^8, 2^16, or 2^32,
266- * depending on the bin index size.
267266 * Note: column-wise kernel is used for dense data only.
268267 */
269- std::size_t max_elem_in_hist_col = 1u << (8 * gidx.index .GetBinTypeSize ());
270- std::size_t hist_col_size = 2 * sizeof (double ) * max_elem_in_hist_col;
268+ std::size_t hist_col_size = 2 * sizeof (double ) * max_bin;
271269 bool hist_col_fit_to_l1 = hist_col_size < usable_l1_size;
272270
273271 /* Second step: compute available L1 space for row data. */
@@ -369,7 +367,7 @@ class MultiHistogramBuilder {
369367 bool read_by_column = ReadByColumn (gidx, force_read_by_column);
370368
371369 auto space = ConstructHistSpace (partitioners, nodes, gidx,
372- cache_manager_.L1Size (), read_by_column);
370+ cache_manager_.L1Size (), param. max_bin , read_by_column);
373371 for (bst_target_t t{0 }; t < n_targets; ++t) {
374372 auto t_gpair = gpair.Slice (linalg::All (), t);
375373 this ->target_builders_ [t].BuildHist (page_idx, space, gidx,
@@ -411,7 +409,7 @@ class MultiHistogramBuilder {
411409 bool read_by_column = ReadByColumn (page, force_read_by_column);
412410
413411 auto space = ConstructHistSpace (partitioners, nodes_to_build, page,
414- cache_manager_.L1Size (), read_by_column);
412+ cache_manager_.L1Size (), param. max_bin , read_by_column);
415413
416414 CHECK_EQ (gpair.Shape (1 ), tree.NumTargets ());
417415 for (bst_target_t t = 0 ; t < tree.NumTargets (); ++t) {
0 commit comments