@@ -294,19 +294,25 @@ template <auto GlobalDense = SharedMemHistKernel<true, false, kBlockThreads, kIt
294294 auto SharedDense = SharedMemHistKernel<true , true , kBlockThreads , kItemsPerThread >,
295295 auto Shared = SharedMemHistKernel<false , true , kBlockThreads , kItemsPerThread >>
296296struct HistogramKernel {
297+ enum KernelType : std::size_t {
298+ kGlobalDense = 0 ,
299+ kGlobal = 1 ,
300+ kSharedDense = 2 ,
301+ kShared = 3 ,
302+ };
297303 // Kernel for working with dense Ellpack using the global memory.
298- decltype (Global ) global_dense_kernel{
304+ decltype (GlobalDense ) global_dense_kernel{
299305 SharedMemHistKernel<true , false , kBlockThreads , kItemsPerThread >};
300306 // Kernel for working with sparse Ellpack using the global memory.
301307 decltype (Global) global_kernel{SharedMemHistKernel<false , false , kBlockThreads , kItemsPerThread >};
302308 // Kernel for working with dense Ellpack using the shared memory.
303- decltype (Shared ) shared_dense_kernel{
309+ decltype (SharedDense ) shared_dense_kernel{
304310 SharedMemHistKernel<true , true , kBlockThreads , kItemsPerThread >};
305311 // Kernel for working with sparse Ellpack using the shared memory.
306312 decltype (Shared) shared_kernel{SharedMemHistKernel<false , true , kBlockThreads , kItemsPerThread >};
307313
308314 bool shared{false };
309- std::uint32_t grid_size{ 0 };
315+ std::array<std:: uint32_t , 4 > grid_sizes{ 0 , 0 , 0 , 0 };
310316 std::size_t smem_size{0 };
311317 bool const force_global;
312318
@@ -321,7 +327,7 @@ struct HistogramKernel {
321327 this ->shared = !force_global_memory && this ->smem_size <= max_shared_memory;
322328 this ->smem_size = this ->shared ? this ->smem_size : 0 ;
323329
324- auto init = [&](auto & kernel) {
330+ auto init = [&](auto & kernel, KernelType k ) {
325331 if (this ->shared ) {
326332 dh::safe_cuda (cudaFuncSetAttribute (kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
327333 max_shared_memory));
@@ -338,11 +344,14 @@ struct HistogramKernel {
338344
339345 // This gives the number of blocks needed to keep the device occupied. Use this as the
340346 // maximum number of blocks.
341- this ->grid_size = n_blocks_per_mp * n_mps;
347+ this ->grid_sizes [ static_cast <std:: size_t >(k)] = n_blocks_per_mp * n_mps;
342348 };
343349 // Initialize all kernel instantiations
350+ std::array kernel_types{kGlobalDense , kGlobal , kSharedDense , kShared };
351+ std::int32_t k = 0 ;
344352 for (auto & kernel : {global_dense_kernel, global_kernel, shared_dense_kernel, shared_kernel}) {
345- init (kernel);
353+ init (kernel, kernel_types[k]);
354+ ++k;
346355 }
347356 }
348357};
@@ -375,29 +384,35 @@ class DeviceHistogramBuilderImpl {
375384 // Allocate the number of blocks such that each block has about kMinItemsPerBlock work,
376385 // up to a maximum where the device is saturated.
377386 auto constexpr kMinItemsPerBlock = ItemsPerTile ();
378- auto grid_size = std::min (kernel_->grid_size , static_cast <std::uint32_t >(common::DivRoundUp (
379- items_per_group, kMinItemsPerBlock )));
380- auto launcher = [&](auto kernel) {
387+
388+ auto launcher = [&](auto const & kernel, std::uint32_t grid_size) {
389+ CHECK_NE (grid_size, 0 );
390+ grid_size = std::min (grid_size, static_cast <std::uint32_t >(
391+ common::DivRoundUp (items_per_group, kMinItemsPerBlock )));
381392 dh::LaunchKernel{dim3 (grid_size, feature_groups.NumGroups ()), // NOLINT
382393 static_cast <uint32_t >(kBlockThreads ), kernel_->smem_size , ctx->Stream ()}(
383394 kernel, matrix, feature_groups, d_ridx, histogram.data (), gpair.data (), rounding);
384395 };
385396
397+ using K = HistogramKernel<>::KernelType;
386398 if (!this ->kernel_ ->shared ) { // Use global memory
387399 CHECK_EQ (this ->kernel_ ->smem_size , 0 );
388400 if (matrix.IsDenseCompressed ()) {
389401 // Dense must use shared memory except for testing.
390402 CHECK (this ->kernel_ ->force_global );
391- launcher (this ->kernel_ ->global_dense_kernel );
403+ launcher (this ->kernel_ ->global_dense_kernel , this -> kernel_ -> grid_sizes [K:: kGlobalDense ] );
392404 } else {
393- launcher (this ->kernel_ ->global_kernel );
405+ // Sparse
406+ launcher (this ->kernel_ ->global_kernel , this ->kernel_ ->grid_sizes [K::kGlobal ]);
394407 }
395408 } else { // Use shared memory
396409 CHECK_NE (this ->kernel_ ->smem_size , 0 );
397410 if (matrix.IsDenseCompressed ()) {
398- launcher (this ->kernel_ ->shared_dense_kernel );
411+ // Dense
412+ launcher (this ->kernel_ ->shared_dense_kernel , this ->kernel_ ->grid_sizes [K::kSharedDense ]);
399413 } else {
400- launcher (this ->kernel_ ->shared_kernel );
414+ // Sparse
415+ launcher (this ->kernel_ ->shared_kernel , this ->kernel_ ->grid_sizes [K::kShared ]);
401416 }
402417 }
403418 }
0 commit comments