@@ -222,7 +222,7 @@ template <typename ValueType>
 __global__ __launch_bounds__(default_block_size) void set_zero(
     const size_type nnz, ValueType *__restrict__ val)
 {
-    const auto ind = size_type(blockDim.x) * blockIdx.x + threadIdx.x;
+    const auto ind = thread::get_thread_id_flat();
     if (ind < nnz) {
         val[ind] = zero<ValueType>();
     }
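
The helper introduced here is defined elsewhere in the library and not shown
in this diff. A minimal sketch of the semantics it is assumed to have,
matching the removed expression (note the widening of blockDim.x before the
multiplication, so the flat id cannot overflow 32 bits):

    // Assumed sketch of thread::get_thread_id_flat(); size_type stands in
    // for the library's own typedef.
    #include <cstddef>
    using size_type = std::size_t;

    template <typename IndexType = size_type>
    __device__ __forceinline__ IndexType get_thread_id_flat()
    {
        return threadIdx.x + static_cast<IndexType>(blockDim.x) * blockIdx.x;
    }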
@@ -438,19 +438,19 @@ __device__ void device_classical_spmv(const size_type num_rows,
                                       ValueType *__restrict__ c,
                                       const size_type c_stride, Closure scale)
 {
-    const auto tid = size_type(blockDim.x) * blockIdx.x + threadIdx.x;
-    const auto subrow = size_type(gridDim.x) * blockDim.x / subwarp_size;
-    const auto subid = tid % subwarp_size;
+    auto subwarp_tile =
+        group::tiled_partition<subwarp_size>(group::this_thread_block());
+    const auto subrow = thread::get_subwarp_num_flat<subwarp_size>();
+    const auto subid = subwarp_tile.thread_rank();
     const auto column_id = blockIdx.y;
-    for (auto row = tid / subwarp_size; row < num_rows; row += subrow) {
+    auto row = thread::get_subwarp_id_flat<subwarp_size>();
+    for (; row < num_rows; row += subrow) {
         const auto ind_end = row_ptrs[row + 1];
         ValueType temp_val = zero<ValueType>();
         for (auto ind = row_ptrs[row] + subid; ind < ind_end;
              ind += subwarp_size) {
             temp_val += val[ind] * b[col_idxs[ind] * b_stride + column_id];
         }
-        auto subwarp_tile =
-            group::tiled_partition<subwarp_size>(group::this_thread_block());
         auto subwarp_result = reduce(
             subwarp_tile, temp_val,
             [](const ValueType &a, const ValueType &b) { return a + b; });
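
Two more helpers appear in this hunk, again assumed rather than shown:
get_subwarp_id_flat (which subwarp_size-thread group the calling thread
belongs to, counted across the whole grid) and get_subwarp_num_flat (how many
such groups the grid contains, i.e. the stride of the row loop). Sketches
consistent with the removed lines, building on get_thread_id_flat above:

    template <int subwarp_size, typename IndexType = size_type>
    __device__ __forceinline__ IndexType get_subwarp_id_flat()
    {
        return get_thread_id_flat<IndexType>() / subwarp_size;
    }

    template <int subwarp_size, typename IndexType = size_type>
    __device__ __forceinline__ IndexType get_subwarp_num_flat()
    {
        return static_cast<IndexType>(gridDim.x) * blockDim.x / subwarp_size;
    }

The hunk also hoists the tiled_partition ahead of the row loop, so the tile
is constructed once per thread instead of once per iteration, and
subwarp_tile.thread_rank() replaces the manual tid % subwarp_size.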
@@ -500,8 +500,7 @@ __global__ __launch_bounds__(default_block_size) void spgeam_nnz(
     const IndexType *b_row_ptrs, const IndexType *b_col_idxs,
     IndexType num_rows, IndexType *nnz)
 {
-    auto row = (threadIdx.x + blockDim.x * static_cast<size_type>(blockIdx.x)) /
-               subwarp_size;
+    auto row = thread::get_subwarp_id_flat<subwarp_size, IndexType>();
     auto subwarp =
         group::tiled_partition<subwarp_size>(group::this_thread_block());
     if (row >= num_rows) {
@@ -533,8 +532,7 @@ __global__ __launch_bounds__(default_block_size) void spgeam(
     const IndexType *b_col_idxs, const ValueType *b_vals, IndexType num_rows,
     const IndexType *c_row_ptrs, IndexType *c_col_idxs, ValueType *c_vals)
 {
-    auto row = (threadIdx.x + blockDim.x * static_cast<size_type>(blockIdx.x)) /
-               subwarp_size;
+    auto row = thread::get_subwarp_id_flat<subwarp_size, IndexType>();
     auto subwarp =
         group::tiled_partition<subwarp_size>(group::this_thread_block());
     if (row >= num_rows) {
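
spgeam_nnz and spgeam pass the row index type as the second template
argument, so the helper returns IndexType directly and the old two-line
division with its static_cast<size_type> disappears at the call site. A
hypothetical device-side check of that equivalence, assuming the sketches
above (same_subwarp_id is not a library function):

    template <int subwarp_size, typename IndexType>
    __device__ bool same_subwarp_id()
    {
        // The manual form is the expression removed by these two hunks.
        const auto manual =
            (threadIdx.x + blockDim.x * static_cast<size_type>(blockIdx.x)) /
            subwarp_size;
        return get_subwarp_id_flat<subwarp_size, IndexType>() ==
               static_cast<IndexType>(manual);
    }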
@@ -591,7 +589,7 @@ __global__ __launch_bounds__(default_block_size) void convert_row_ptrs_to_idxs(
     size_type num_rows, const IndexType *__restrict__ ptrs,
     IndexType *__restrict__ idxs)
 {
-    const auto tidx = threadIdx.x + blockDim.x * blockIdx.x;
+    const auto tidx = thread::get_thread_id_flat();
     if (tidx < num_rows) {
         for (auto i = ptrs[tidx]; i < ptrs[tidx + 1]; i++) {
             idxs[i] = tidx;
@@ -620,7 +618,7 @@ __global__ __launch_bounds__(default_block_size) void fill_in_dense(
     const ValueType *__restrict__ values, size_type stride,
     ValueType *__restrict__ result)
 {
-    const auto tidx = threadIdx.x + blockDim.x * blockIdx.x;
+    const auto tidx = thread::get_thread_id_flat();
     if (tidx < num_rows) {
         for (auto i = row_ptrs[tidx]; i < row_ptrs[tidx + 1]; i++) {
             result[stride * tidx + col_idxs[i]] = values[i];
@@ -634,7 +632,7 @@ __global__ __launch_bounds__(default_block_size) void calculate_nnz_per_row(
     size_type num_rows, const IndexType *__restrict__ row_ptrs,
     size_type *__restrict__ nnz_per_row)
 {
-    const auto tidx = threadIdx.x + blockIdx.x * blockDim.x;
+    const auto tidx = thread::get_thread_id_flat();
     if (tidx < num_rows) {
         nnz_per_row[tidx] = row_ptrs[tidx + 1] - row_ptrs[tidx];
     }
@@ -685,7 +683,7 @@ __global__ __launch_bounds__(default_block_size) void fill_in_sellp(
     IndexType *__restrict__ result_col_idxs,
     ValueType *__restrict__ result_values)
 {
-    const auto global_row = threadIdx.x + blockIdx.x * blockDim.x;
+    const auto global_row = thread::get_thread_id_flat();
     const auto row = global_row % slice_size;
     const auto sliceid = global_row / slice_size;
 
@@ -714,7 +712,7 @@ __global__ __launch_bounds__(default_block_size) void initialize_zero_ell(
     size_type max_nnz_per_row, size_type stride, ValueType *__restrict__ values,
     IndexType *__restrict__ col_idxs)
 {
-    const auto tidx = threadIdx.x + blockIdx.x * blockDim.x;
+    const auto tidx = thread::get_thread_id_flat();
 
     if (tidx < stride * max_nnz_per_row) {
         values[tidx] = zero<ValueType>();
@@ -732,10 +730,9 @@ __global__ __launch_bounds__(default_block_size) void fill_in_ell(
     ValueType *__restrict__ result_values,
     IndexType *__restrict__ result_col_idxs)
 {
-    const auto tidx = threadIdx.x + blockIdx.x * blockDim.x;
     constexpr auto warp_size = config::warp_size;
-    const auto row = tidx / warp_size;
-    const auto local_tidx = tidx % warp_size;
+    const auto row = thread::get_subwarp_id_flat<warp_size>();
+    const auto local_tidx = threadIdx.x % warp_size;
 
     if (row < num_rows) {
         for (size_type i = local_tidx;
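
fill_in_ell (and fill_in_hybrid further down) now derives local_tidx from
threadIdx.x % warp_size instead of the flat id. The two agree whenever
blockDim.x is a multiple of warp_size, which a default_block_size launch is
assumed to guarantee:

    //   tidx             = threadIdx.x + blockIdx.x * blockDim.x
    //   tidx % warp_size = threadIdx.x % warp_size,
    // since blockIdx.x * blockDim.x is then itself a multiple of warp_size.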
@@ -754,10 +751,11 @@ __global__ __launch_bounds__(default_block_size) void reduce_max_nnz_per_slice(
     size_type num_rows, size_type slice_size, size_type stride_factor,
     const size_type *__restrict__ nnz_per_row, size_type *__restrict__ result)
 {
-    const auto tidx = threadIdx.x + blockIdx.x * blockDim.x;
     constexpr auto warp_size = config::warp_size;
-    const auto warpid = tidx / warp_size;
-    const auto tid_in_warp = tidx % warp_size;
+    auto warp_tile =
+        group::tiled_partition<warp_size>(group::this_thread_block());
+    const auto warpid = thread::get_subwarp_id_flat<warp_size>();
+    const auto tid_in_warp = warp_tile.thread_rank();
     const auto slice_num = ceildiv(num_rows, slice_size);
 
     size_type thread_result = 0;
@@ -767,9 +765,6 @@ __global__ __launch_bounds__(default_block_size) void reduce_max_nnz_per_slice(
                 max(thread_result, nnz_per_row[warpid * slice_size + i]);
         }
     }
-
-    auto warp_tile =
-        group::tiled_partition<warp_size>(group::this_thread_block());
     auto warp_result = reduce(
         warp_tile, thread_result,
         [](const size_type &a, const size_type &b) { return max(a, b); });
@@ -818,7 +813,7 @@ __global__
     IndexType *__restrict__ csr_row_idxs,
     size_type *__restrict__ coo_row_nnz)
 {
-    const auto tidx = threadIdx.x + blockIdx.x * blockDim.x;
+    const auto tidx = thread::get_thread_id_flat();
     if (tidx < num_rows) {
         const size_type csr_nnz = csr_row_idxs[tidx + 1] - csr_row_idxs[tidx];
         coo_row_nnz[tidx] =
@@ -840,10 +835,9 @@ __global__ __launch_bounds__(default_block_size) void fill_in_hybrid(
     IndexType *__restrict__ result_coo_col,
     IndexType *__restrict__ result_coo_row)
 {
-    const auto tidx = threadIdx.x + blockIdx.x * blockDim.x;
     constexpr auto warp_size = config::warp_size;
-    const auto row = tidx / warp_size;
-    const auto local_tidx = tidx % warp_size;
+    const auto row = thread::get_subwarp_id_flat<warp_size>();
+    const auto local_tidx = threadIdx.x % warp_size;
 
     if (row < num_rows) {
         for (size_type i = local_tidx;
@@ -876,7 +870,7 @@ template <typename ValueType>
 __global__ __launch_bounds__(default_block_size) void conjugate_kernel(
     size_type num_nonzeros, ValueType *__restrict__ val)
 {
-    const auto tidx = size_type(blockIdx.x) * default_block_size + threadIdx.x;
+    const auto tidx = thread::get_thread_id_flat();
 
     if (tidx < num_nonzeros) {
         val[tidx] = conj(val[tidx]);
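
For context, a self-contained toy program (hypothetical, not part of this
commit) showing the refactored pattern end to end with the helper sketched
earlier; zero<ValueType>() is replaced by value initialization:

    #include <cstddef>
    #include <cstdio>
    #include <cuda_runtime.h>

    using size_type = std::size_t;
    constexpr int default_block_size = 512;

    // Assumed flat-id helper, as sketched at the top of this diff.
    template <typename IndexType = size_type>
    __device__ __forceinline__ IndexType get_thread_id_flat()
    {
        return threadIdx.x + static_cast<IndexType>(blockDim.x) * blockIdx.x;
    }

    template <typename ValueType>
    __global__ __launch_bounds__(default_block_size) void set_zero(
        const size_type nnz, ValueType *__restrict__ val)
    {
        const auto ind = get_thread_id_flat();
        if (ind < nnz) {
            val[ind] = ValueType{};
        }
    }

    int main()
    {
        const size_type nnz = 1000;
        double *d_val;
        cudaMalloc(&d_val, nnz * sizeof(double));
        const auto grid = static_cast<unsigned>(
            (nnz + default_block_size - 1) / default_block_size);
        set_zero<<<grid, default_block_size>>>(nnz, d_val);
        cudaDeviceSynchronize();
        cudaFree(d_val);
        std::printf("set_zero of %zu values done\n", nnz);
    }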