@@ -14,40 +14,39 @@ static inline __device__ void ggml_cuda_swap(T & a, T & b) {
 template <int order, typename T>
 static __device__ void k_argsort(const T * x, uint32_t * dst, const int ncols, int ncols_pad) {
     // bitonic sort
-    int col = threadIdx.x;
     int row = blockIdx.x;
 
-    if (col >= ncols_pad) {
-        return;
-    }
-
     const T * x_row = x + row * ncols;
     extern __shared__ int dst_row[];
 
-    // initialize indices
-    dst_row[col] = col;
+    // initialize indices - each thread handles multiple elements if ncols_pad > blockDim.x
+    for (int col = threadIdx.x; col < ncols_pad; col += blockDim.x) {
+        dst_row[col] = col;
+    }
 
     __syncthreads();
 
     for (int k = 2; k <= ncols_pad; k *= 2) {
         for (int j = k / 2; j > 0; j /= 2) {
-            int ixj = col ^ j;
-            if (ixj > col) {
-                if ((col & k) == 0) {
-                    if (dst_row[col] >= ncols ||
-                        (dst_row[ixj] < ncols && (order == SORT_ORDER_ASC ?
-                            x_row[dst_row[col]] > x_row[dst_row[ixj]] :
-                            x_row[dst_row[col]] < x_row[dst_row[ixj]]))
-                    ) {
-                        ggml_cuda_swap(dst_row[col], dst_row[ixj]);
-                    }
-                } else {
-                    if (dst_row[ixj] >= ncols ||
-                        (dst_row[col] < ncols && (order == SORT_ORDER_ASC ?
-                            x_row[dst_row[col]] < x_row[dst_row[ixj]] :
-                            x_row[dst_row[col]] > x_row[dst_row[ixj]]))
-                    ) {
-                        ggml_cuda_swap(dst_row[col], dst_row[ixj]);
+            for (int col = threadIdx.x; col < ncols_pad; col += blockDim.x) {
+                int ixj = col ^ j;
+                if (ixj > col) {
+                    if ((col & k) == 0) {
+                        if (dst_row[col] >= ncols ||
+                            (dst_row[ixj] < ncols && (order == SORT_ORDER_ASC ?
+                                x_row[dst_row[col]] > x_row[dst_row[ixj]] :
+                                x_row[dst_row[col]] < x_row[dst_row[ixj]]))
+                        ) {
+                            ggml_cuda_swap(dst_row[col], dst_row[ixj]);
+                        }
+                    } else {
+                        if (dst_row[ixj] >= ncols ||
+                            (dst_row[col] < ncols && (order == SORT_ORDER_ASC ?
+                                x_row[dst_row[col]] < x_row[dst_row[ixj]] :
+                                x_row[dst_row[col]] > x_row[dst_row[ixj]]))
+                        ) {
+                            ggml_cuda_swap(dst_row[col], dst_row[ixj]);
+                        }
                     }
                 }
             }
@@ -56,7 +55,7 @@ static __device__ void k_argsort(const T * x, uint32_t * dst, const int ncols, i
     }
 
     // copy the result to dst without the padding
-    if (col < ncols) {
+    for (int col = threadIdx.x; col < ncols; col += blockDim.x) {
         dst[row * ncols + col] = dst_row[col];
     }
 }
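
Below is a minimal host-side sketch of how a kernel like this could be launched once the grid-stride loops are in place. The __global__ wrapper, the next_power_of_2 helper, the SORT_ORDER_ASC definition, and the 1024-thread cap are illustrative assumptions and are not part of this commit.

#include <cuda_runtime.h>
#include <cstdint>

enum { SORT_ORDER_ASC = 0, SORT_ORDER_DESC = 1 }; // assumed to mirror the order values used by k_argsort

// hypothetical __global__ wrapper so the __device__ routine above can be launched directly
template <int order, typename T>
static __global__ void k_argsort_kernel(const T * x, uint32_t * dst, const int ncols, const int ncols_pad) {
    k_argsort<order, T>(x, dst, ncols, ncols_pad);
}

// smallest power of two >= x, since the bitonic network needs a power-of-2 width
static int next_power_of_2(int x) {
    int n = 1;
    while (n < x) {
        n *= 2;
    }
    return n;
}

template <typename T>
static void argsort_asc_cuda(const T * x, uint32_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
    const int    ncols_pad  = next_power_of_2(ncols);
    const int    block_size = ncols_pad < 1024 ? ncols_pad : 1024; // no longer needs one thread per padded column
    const size_t shared_mem = ncols_pad * sizeof(int);             // one index slot per padded column

    // one block per row, matching the blockIdx.x usage in k_argsort
    k_argsort_kernel<SORT_ORDER_ASC, T><<<nrows, block_size, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
}

With a launch scheme like this, the block size is no longer a hard ceiling on ncols_pad; the remaining limit is the per-block shared-memory allocation of ncols_pad * sizeof(int) bytes.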