Commit 46a7381

Extract thread ID computation from GPU kernels

This PR extracts most thread ID calculations into separate helper functions for improved overflow and type safety. Related PR: #464

2 parents be79c28 + c171a00 commit 46a7381
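
Background on the overflow issue: threadIdx.x, blockDim.x, and blockIdx.x are all unsigned 32-bit built-ins, so the common pattern threadIdx.x + blockDim.x * blockIdx.x is evaluated in 32-bit arithmetic and silently wraps once a grid addresses more than 2^32 elements. The new helpers widen one operand to the target index type before multiplying. A minimal sketch of the difference (copy_kernel and its parameters are hypothetical; only the index arithmetic mirrors the change):

#include <cstdint>

__global__ void copy_kernel(std::int64_t n, const int *__restrict__ in,
                            int *__restrict__ out)
{
    // Old pattern: every operand is a 32-bit unsigned built-in, so
    // blockDim.x * blockIdx.x wraps around at 2^32.
    // const auto tidx = threadIdx.x + blockDim.x * blockIdx.x;

    // New pattern (what thread::get_thread_id_flat() encapsulates):
    // widening one operand promotes the whole expression to 64 bits.
    const auto tidx =
        threadIdx.x + static_cast<std::int64_t>(blockDim.x) * blockIdx.x;
    if (tidx < n) {
        out[tidx] = in[tidx];
    }
}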

55 files changed: +241, -165 lines

common/components/prefix_sum.hpp.inc

Lines changed: 2 additions & 2 deletions

@@ -50,7 +50,7 @@ __global__ __launch_bounds__(block_size) void start_prefix_sum(
     size_type num_elements, ValueType *__restrict__ elements,
     ValueType *__restrict__ block_sum)
 {
-    const auto tidx = threadIdx.x + blockDim.x * blockIdx.x;
+    const auto tidx = thread::get_thread_id_flat();
     const auto element_id = threadIdx.x;
     __shared__ size_type prefix_helper[block_size];
     prefix_helper[element_id] =
@@ -113,7 +113,7 @@ __global__ __launch_bounds__(block_size) void finalize_prefix_sum(
     size_type num_elements, ValueType *__restrict__ elements,
     const ValueType *__restrict__ block_sum)
 {
-    const auto tidx = threadIdx.x + blockIdx.x * blockDim.x;
+    const auto tidx = thread::get_thread_id_flat();

     if (tidx < num_elements) {
         ValueType prefix_block_sum = zero<ValueType>();

common/components/reduction.hpp.inc

Lines changed: 1 addition & 1 deletion

@@ -142,7 +142,7 @@ __device__ void reduce_array(size_type size,
                              ValueType *__restrict__ result,
                              Operator reduce_op = Operator{})
 {
-    const auto tidx = threadIdx.x + blockIdx.x * blockDim.x;
+    const auto tidx = thread::get_thread_id_flat();
     auto thread_result = zero<ValueType>();
     for (auto i = tidx; i < size; i += blockDim.x * gridDim.x) {
         thread_result = reduce_op(thread_result, source[i]);

common/components/thread_ids.hpp.inc

Lines changed: 77 additions & 0 deletions

@@ -192,4 +192,81 @@ __device__ __forceinline__ size_type get_thread_id()
 {
     return get_subwarp_id<subwarp_size, warps_per_block>() * subwarp_size +
            threadIdx.x;
+}
+
+
+/**
+ * @internal
+ *
+ * Returns the global ID of the thread in the given index type.
+ * This function assumes one-dimensional thread and block indexing.
+ *
+ * @return the global ID of the thread in the given index type.
+ *
+ * @tparam IndexType  the index type
+ */
+template <typename IndexType = size_type>
+__device__ __forceinline__ IndexType get_thread_id_flat()
+{
+    return threadIdx.x + static_cast<IndexType>(blockDim.x) * blockIdx.x;
+}
+
+
+/**
+ * @internal
+ *
+ * Returns the total number of threads in the given index type.
+ * This function assumes one-dimensional thread and block indexing.
+ *
+ * @return the total number of threads in the given index type.
+ *
+ * @tparam IndexType  the index type
+ */
+template <typename IndexType = size_type>
+__device__ __forceinline__ IndexType get_thread_num_flat()
+{
+    return blockDim.x * static_cast<IndexType>(gridDim.x);
+}
+
+
+/**
+ * @internal
+ *
+ * Returns the global ID of the subwarp in the given index type.
+ * This function assumes one-dimensional thread and block indexing
+ * with a power of two block size of at least subwarp_size.
+ *
+ * @return the global ID of the subwarp in the given index type.
+ *
+ * @tparam subwarp_size  the size of the subwarp. Must be a power of two!
+ * @tparam IndexType  the index type
+ */
+template <int subwarp_size, typename IndexType = size_type>
+__device__ __forceinline__ IndexType get_subwarp_id_flat()
+{
+    static_assert(!(subwarp_size & (subwarp_size - 1)),
+                  "subwarp_size must be a power of two");
+    return threadIdx.x / subwarp_size +
+           static_cast<IndexType>(blockDim.x / subwarp_size) * blockIdx.x;
+}
+
+
+/**
+ * @internal
+ *
+ * Returns the total number of subwarps in the given index type.
+ * This function assumes one-dimensional thread and block indexing
+ * with a power of two block size of at least subwarp_size.
+ *
+ * @return the total number of subwarps in the given index type.
+ *
+ * @tparam subwarp_size  the size of the subwarp. Must be a power of two!
+ * @tparam IndexType  the index type
+ */
+template <int subwarp_size, typename IndexType = size_type>
+__device__ __forceinline__ IndexType get_subwarp_num_flat()
+{
+    static_assert(!(subwarp_size & (subwarp_size - 1)),
+                  "subwarp_size must be a power of two");
+    return blockDim.x / subwarp_size * static_cast<IndexType>(gridDim.x);
 }
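
The two flat helpers pair naturally in a grid-stride loop: get_thread_id_flat() gives the starting index and get_thread_num_flat() the stride, both already in the requested index type. A sketch of a caller (scale_array is a hypothetical kernel; the thread:: helpers are the ones added above):

template <typename IndexType, typename ValueType>
__global__ void scale_array(IndexType n, ValueType factor,
                            ValueType *__restrict__ data)
{
    // Start at this thread's flat ID and advance by the total number of
    // threads, so any grid size covers all n elements.
    const auto begin = thread::get_thread_id_flat<IndexType>();
    const auto stride = thread::get_thread_num_flat<IndexType>();
    for (auto i = begin; i < n; i += stride) {
        data[i] *= factor;
    }
}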

common/components/zero_array.hpp.inc

Lines changed: 1 addition & 2 deletions

@@ -38,8 +38,7 @@ template <typename ValueType>
 __global__ __launch_bounds__(default_block_size) void zero_array(
     size_type n, ValueType *__restrict__ array)
 {
-    const auto tidx =
-        static_cast<size_type>(blockDim.x) * blockIdx.x + threadIdx.x;
+    const auto tidx = thread::get_thread_id_flat();
     if (tidx < n) {
         array[tidx] = zero<ValueType>();
     }

common/factorization/par_ilu_kernels.hpp.inc

Lines changed: 16 additions & 20 deletions

@@ -93,18 +93,17 @@ __global__
     IndexType *__restrict__ elements_to_add_per_row,
     bool *__restrict__ changes_required)
 {
-    const auto total_thread_count =
-        static_cast<size_type>(blockDim.x) * gridDim.x / SubwarpSize;
-    const auto tidx =
-        threadIdx.x + static_cast<size_type>(blockIdx.x) * blockDim.x;
-    const auto begin_row = static_cast<IndexType>(tidx / SubwarpSize);
+    const auto total_subwarp_count =
+        thread::get_subwarp_num_flat<SubwarpSize, IndexType>();
+    const auto begin_row =
+        thread::get_subwarp_id_flat<SubwarpSize, IndexType>();

     auto thread_block = group::this_thread_block();
     auto subwarp_grp = group::tiled_partition<SubwarpSize>(thread_block);
     const auto subwarp_idx = subwarp_grp.thread_rank();

     bool local_change{false};
-    for (IndexType row = begin_row; row < num_rows; row += total_thread_count) {
+    for (auto row = begin_row; row < num_rows; row += total_subwarp_count) {
         if (row >= num_cols) {
             if (subwarp_idx == 0) {
                 elements_to_add_per_row[row] = 0;
@@ -145,17 +144,16 @@ __global__
     const IndexType *__restrict__ row_ptrs_addition)
 {
     // Precaution in case not enough threads were created
-    const auto total_thread_count =
-        static_cast<size_type>(blockDim.x) * gridDim.x / SubwarpSize;
-    const auto tidx =
-        threadIdx.x + static_cast<size_type>(blockIdx.x) * blockDim.x;
-    const auto begin_row = static_cast<IndexType>(tidx / SubwarpSize);
+    const auto total_subwarp_count =
+        thread::get_subwarp_num_flat<SubwarpSize, IndexType>();
+    const auto begin_row =
+        thread::get_subwarp_id_flat<SubwarpSize, IndexType>();

     auto thread_block = group::this_thread_block();
     auto subwarp_grp = group::tiled_partition<SubwarpSize>(thread_block);
     const auto subwarp_idx = subwarp_grp.thread_rank();

-    for (IndexType row = begin_row; row < num_rows; row += total_thread_count) {
+    for (auto row = begin_row; row < num_rows; row += total_subwarp_count) {
         const IndexType old_row_start{old_row_ptrs[row]};
         const IndexType old_row_end{old_row_ptrs[row + 1]};
         const IndexType new_row_start{old_row_start + row_ptrs_addition[row]};
@@ -223,12 +221,10 @@ __global__ __launch_bounds__(default_block_size) void update_row_ptrs(
     IndexType num_rows, IndexType *__restrict__ row_ptrs,
     IndexType *__restrict__ row_ptr_addition)
 {
-    const auto total_thread_count =
-        static_cast<size_type>(blockDim.x) * gridDim.x;
-    const auto begin_row =
-        threadIdx.x + static_cast<size_type>(blockIdx.x) * blockDim.x;
+    const auto total_thread_count = thread::get_thread_num_flat<IndexType>();
+    const auto begin_row = thread::get_thread_id_flat<IndexType>();

-    for (IndexType row = begin_row; row < num_rows; row += total_thread_count) {
+    for (auto row = begin_row; row < num_rows; row += total_thread_count) {
         row_ptrs[row] += row_ptr_addition[row];
     }
 }
@@ -241,7 +237,7 @@ __global__ __launch_bounds__(default_block_size) void count_nnz_per_l_u_row(
     const ValueType *__restrict__ values, IndexType *__restrict__ l_nnz_row,
     IndexType *__restrict__ u_nnz_row)
 {
-    const auto row = blockDim.x * blockIdx.x + threadIdx.x;
+    const auto row = thread::get_thread_id_flat<IndexType>();
     if (row < num_rows) {
         IndexType l_row_nnz{};
         IndexType u_row_nnz{};
@@ -266,7 +262,7 @@ __global__ __launch_bounds__(default_block_size) void initialize_l_u(
     const IndexType *__restrict__ u_row_ptrs,
     IndexType *__restrict__ u_col_idxs, ValueType *__restrict__ u_values)
 {
-    const auto row = blockDim.x * blockIdx.x + threadIdx.x;
+    const auto row = thread::get_thread_id_flat<IndexType>();
     if (row < num_rows) {
         auto l_idx = l_row_ptrs[row];
         auto u_idx = u_row_ptrs[row];
@@ -298,7 +294,7 @@ __global__ __launch_bounds__(default_block_size) void compute_l_u_factors(
     const IndexType *__restrict__ u_row_ptrs,
     const IndexType *__restrict__ u_col_idxs, ValueType *__restrict__ u_values)
 {
-    const auto elem_id = blockDim.x * blockIdx.x + threadIdx.x;
+    const auto elem_id = thread::get_thread_id_flat<IndexType>();
     if (elem_id < num_elements) {
         const auto row = row_idxs[elem_id];
         const auto col = col_idxs[elem_id];
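
The kernels above show the subwarp variant of the same pattern: each group of SubwarpSize threads owns one row, and the loop strides by the total subwarp count. A condensed sketch of that idiom (scale_rows is illustrative only and assumes the thread:: helpers from thread_ids.hpp.inc; lanes touch disjoint entries, so no reduction is needed):

template <int subwarp_size, typename IndexType, typename ValueType>
__global__ void scale_rows(IndexType num_rows,
                           const IndexType *__restrict__ row_ptrs,
                           const ValueType *__restrict__ row_scales,
                           ValueType *__restrict__ values)
{
    // One subwarp per row; rows are visited with a subwarp-stride loop.
    const auto begin_row =
        thread::get_subwarp_id_flat<subwarp_size, IndexType>();
    const auto subwarp_count =
        thread::get_subwarp_num_flat<subwarp_size, IndexType>();
    const auto lane = static_cast<IndexType>(threadIdx.x % subwarp_size);
    for (auto row = begin_row; row < num_rows; row += subwarp_count) {
        // Lanes of the subwarp cover the row's entries in strides.
        for (auto i = row_ptrs[row] + lane; i < row_ptrs[row + 1];
             i += subwarp_size) {
            values[i] *= row_scales[row];
        }
    }
}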

common/matrix/coo_kernels.hpp.inc

Lines changed: 2 additions & 2 deletions

@@ -228,7 +228,7 @@ __global__ __launch_bounds__(default_block_size) void convert_row_idxs_to_ptrs(
     const IndexType *__restrict__ idxs, size_type num_nonzeros,
     IndexType *__restrict__ ptrs, size_type length)
 {
-    const auto tidx = threadIdx.x + blockIdx.x * blockDim.x;
+    const auto tidx = thread::get_thread_id_flat();

     if (tidx == 0) {
         ptrs[0] = 0;
@@ -265,7 +265,7 @@ __global__ __launch_bounds__(default_block_size) void fill_in_dense(
     const ValueType *__restrict__ values, size_type stride,
     ValueType *__restrict__ result)
 {
-    const auto tidx = threadIdx.x + blockDim.x * blockIdx.x;
+    const auto tidx = thread::get_thread_id_flat();
     if (tidx < nnz) {
         result[stride * row_idxs[tidx] + col_idxs[tidx]] = values[tidx];
     }

common/matrix/csr_kernels.hpp.inc

Lines changed: 24 additions & 30 deletions

@@ -222,7 +222,7 @@ template <typename ValueType>
 __global__ __launch_bounds__(default_block_size) void set_zero(
     const size_type nnz, ValueType *__restrict__ val)
 {
-    const auto ind = size_type(blockDim.x) * blockIdx.x + threadIdx.x;
+    const auto ind = thread::get_thread_id_flat();
     if (ind < nnz) {
         val[ind] = zero<ValueType>();
     }
@@ -438,19 +438,19 @@ __device__ void device_classical_spmv(const size_type num_rows,
                                       ValueType *__restrict__ c,
                                       const size_type c_stride, Closure scale)
 {
-    const auto tid = size_type(blockDim.x) * blockIdx.x + threadIdx.x;
-    const auto subrow = size_type(gridDim.x) * blockDim.x / subwarp_size;
-    const auto subid = tid % subwarp_size;
+    auto subwarp_tile =
+        group::tiled_partition<subwarp_size>(group::this_thread_block());
+    const auto subrow = thread::get_subwarp_num_flat<subwarp_size>();
+    const auto subid = subwarp_tile.thread_rank();
     const auto column_id = blockIdx.y;
-    for (auto row = tid / subwarp_size; row < num_rows; row += subrow) {
+    auto row = thread::get_subwarp_id_flat<subwarp_size>();
+    for (; row < num_rows; row += subrow) {
         const auto ind_end = row_ptrs[row + 1];
         ValueType temp_val = zero<ValueType>();
         for (auto ind = row_ptrs[row] + subid; ind < ind_end;
              ind += subwarp_size) {
             temp_val += val[ind] * b[col_idxs[ind] * b_stride + column_id];
         }
-        auto subwarp_tile =
-            group::tiled_partition<subwarp_size>(group::this_thread_block());
         auto subwarp_result = reduce(
             subwarp_tile, temp_val,
             [](const ValueType &a, const ValueType &b) { return a + b; });
@@ -500,8 +500,7 @@ __global__ __launch_bounds__(default_block_size) void spgeam_nnz(
     const IndexType *b_row_ptrs, const IndexType *b_col_idxs,
     IndexType num_rows, IndexType *nnz)
 {
-    auto row = (threadIdx.x + blockDim.x * static_cast<size_type>(blockIdx.x)) /
-               subwarp_size;
+    auto row = thread::get_subwarp_id_flat<subwarp_size, IndexType>();
     auto subwarp =
         group::tiled_partition<subwarp_size>(group::this_thread_block());
     if (row >= num_rows) {
@@ -533,8 +532,7 @@ __global__ __launch_bounds__(default_block_size) void spgeam(
     const IndexType *b_col_idxs, const ValueType *b_vals, IndexType num_rows,
     const IndexType *c_row_ptrs, IndexType *c_col_idxs, ValueType *c_vals)
 {
-    auto row = (threadIdx.x + blockDim.x * static_cast<size_type>(blockIdx.x)) /
-               subwarp_size;
+    auto row = thread::get_subwarp_id_flat<subwarp_size, IndexType>();
     auto subwarp =
         group::tiled_partition<subwarp_size>(group::this_thread_block());
     if (row >= num_rows) {
@@ -591,7 +589,7 @@ __global__ __launch_bounds__(default_block_size) void convert_row_ptrs_to_idxs(
     size_type num_rows, const IndexType *__restrict__ ptrs,
     IndexType *__restrict__ idxs)
 {
-    const auto tidx = threadIdx.x + blockDim.x * blockIdx.x;
+    const auto tidx = thread::get_thread_id_flat();
     if (tidx < num_rows) {
         for (auto i = ptrs[tidx]; i < ptrs[tidx + 1]; i++) {
             idxs[i] = tidx;
@@ -620,7 +618,7 @@ __global__ __launch_bounds__(default_block_size) void fill_in_dense(
     const ValueType *__restrict__ values, size_type stride,
     ValueType *__restrict__ result)
 {
-    const auto tidx = threadIdx.x + blockDim.x * blockIdx.x;
+    const auto tidx = thread::get_thread_id_flat();
     if (tidx < num_rows) {
         for (auto i = row_ptrs[tidx]; i < row_ptrs[tidx + 1]; i++) {
             result[stride * tidx + col_idxs[i]] = values[i];
@@ -634,7 +632,7 @@ __global__ __launch_bounds__(default_block_size) void calculate_nnz_per_row(
     size_type num_rows, const IndexType *__restrict__ row_ptrs,
     size_type *__restrict__ nnz_per_row)
 {
-    const auto tidx = threadIdx.x + blockIdx.x * blockDim.x;
+    const auto tidx = thread::get_thread_id_flat();
     if (tidx < num_rows) {
         nnz_per_row[tidx] = row_ptrs[tidx + 1] - row_ptrs[tidx];
     }
@@ -685,7 +683,7 @@ __global__ __launch_bounds__(default_block_size) void fill_in_sellp(
     IndexType *__restrict__ result_col_idxs,
     ValueType *__restrict__ result_values)
 {
-    const auto global_row = threadIdx.x + blockIdx.x * blockDim.x;
+    const auto global_row = thread::get_thread_id_flat();
     const auto row = global_row % slice_size;
     const auto sliceid = global_row / slice_size;

@@ -714,7 +712,7 @@ __global__ __launch_bounds__(default_block_size) void initialize_zero_ell(
     size_type max_nnz_per_row, size_type stride, ValueType *__restrict__ values,
     IndexType *__restrict__ col_idxs)
 {
-    const auto tidx = threadIdx.x + blockIdx.x * blockDim.x;
+    const auto tidx = thread::get_thread_id_flat();

     if (tidx < stride * max_nnz_per_row) {
         values[tidx] = zero<ValueType>();
@@ -732,10 +730,9 @@ __global__ __launch_bounds__(default_block_size) void fill_in_ell(
     ValueType *__restrict__ result_values,
     IndexType *__restrict__ result_col_idxs)
 {
-    const auto tidx = threadIdx.x + blockIdx.x * blockDim.x;
     constexpr auto warp_size = config::warp_size;
-    const auto row = tidx / warp_size;
-    const auto local_tidx = tidx % warp_size;
+    const auto row = thread::get_subwarp_id_flat<warp_size>();
+    const auto local_tidx = threadIdx.x % warp_size;

     if (row < num_rows) {
         for (size_type i = local_tidx;
@@ -754,10 +751,11 @@ __global__ __launch_bounds__(default_block_size) void reduce_max_nnz_per_slice(
     size_type num_rows, size_type slice_size, size_type stride_factor,
     const size_type *__restrict__ nnz_per_row, size_type *__restrict__ result)
 {
-    const auto tidx = threadIdx.x + blockIdx.x * blockDim.x;
     constexpr auto warp_size = config::warp_size;
-    const auto warpid = tidx / warp_size;
-    const auto tid_in_warp = tidx % warp_size;
+    auto warp_tile =
+        group::tiled_partition<warp_size>(group::this_thread_block());
+    const auto warpid = thread::get_subwarp_id_flat<warp_size>();
+    const auto tid_in_warp = warp_tile.thread_rank();
     const auto slice_num = ceildiv(num_rows, slice_size);

     size_type thread_result = 0;
@@ -767,9 +765,6 @@ __global__ __launch_bounds__(default_block_size) void reduce_max_nnz_per_slice(
                 max(thread_result, nnz_per_row[warpid * slice_size + i]);
         }
     }
-
-    auto warp_tile =
-        group::tiled_partition<warp_size>(group::this_thread_block());
     auto warp_result = reduce(
         warp_tile, thread_result,
         [](const size_type &a, const size_type &b) { return max(a, b); });
@@ -818,7 +813,7 @@ __global__
     IndexType *__restrict__ csr_row_idxs,
     size_type *__restrict__ coo_row_nnz)
 {
-    const auto tidx = threadIdx.x + blockIdx.x * blockDim.x;
+    const auto tidx = thread::get_thread_id_flat();
     if (tidx < num_rows) {
         const size_type csr_nnz = csr_row_idxs[tidx + 1] - csr_row_idxs[tidx];
         coo_row_nnz[tidx] =
@@ -840,10 +835,9 @@ __global__ __launch_bounds__(default_block_size) void fill_in_hybrid(
     IndexType *__restrict__ result_coo_col,
     IndexType *__restrict__ result_coo_row)
 {
-    const auto tidx = threadIdx.x + blockIdx.x * blockDim.x;
     constexpr auto warp_size = config::warp_size;
-    const auto row = tidx / warp_size;
-    const auto local_tidx = tidx % warp_size;
+    const auto row = thread::get_subwarp_id_flat<warp_size>();
+    const auto local_tidx = threadIdx.x % warp_size;

     if (row < num_rows) {
         for (size_type i = local_tidx;
@@ -876,7 +870,7 @@ template <typename ValueType>
 __global__ __launch_bounds__(default_block_size) void conjugate_kernel(
     size_type num_nonzeros, ValueType *__restrict__ val)
 {
-    const auto tidx = size_type(blockIdx.x) * default_block_size + threadIdx.x;
+    const auto tidx = thread::get_thread_id_flat();

     if (tidx < num_nonzeros) {
         val[tidx] = conj(val[tidx]);
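
Besides swapping in the helpers, device_classical_spmv and reduce_max_nnz_per_slice above also hoist the group::tiled_partition call out of their row loops, so the cooperative group is constructed once and its thread_rank() replaces the tidx % subwarp_size arithmetic. A condensed sketch of the hoisted shape, written against CUDA's own cooperative_groups API rather than the library's group:: wrapper (dense_row_sums and all its names are illustrative):

#include <cooperative_groups.h>

namespace cg = cooperative_groups;

template <int tile_size>
__global__ void dense_row_sums(int num_rows, int num_cols,
                               const float *__restrict__ mat,
                               float *__restrict__ sums)
{
    // Hoisted out of the row loop: the tile is built once per thread.
    auto tile = cg::tiled_partition<tile_size>(cg::this_thread_block());
    const int tiles_per_block = blockDim.x / tile_size;
    const long long num_tiles =
        static_cast<long long>(tiles_per_block) * gridDim.x;
    long long row = static_cast<long long>(tiles_per_block) * blockIdx.x +
                    threadIdx.x / tile_size;
    for (; row < num_rows; row += num_tiles) {
        float partial = 0.0f;
        // Lanes cover the row's columns in strides of tile_size.
        for (int col = tile.thread_rank(); col < num_cols;
             col += tile_size) {
            partial += mat[row * num_cols + col];
        }
        // Tile-wide sum, reusing the hoisted tile on every iteration.
        for (int offset = tile_size / 2; offset > 0; offset /= 2) {
            partial += tile.shfl_down(partial, offset);
        }
        if (tile.thread_rank() == 0) {
            sums[row] = partial;
        }
    }
}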
