Commit 7e6fb8b

candle-core/candle-kernels: lift the limitation of 1024 for sorting on cuda
1 parent 99f6c4c commit 7e6fb8b

2 files changed: +27 -26 lines

candle-core/src/sort.rs

Lines changed: 3 additions & 1 deletion
@@ -85,9 +85,11 @@ mod cuda {
         let ncols = self.last_dim;
         let nrows = elem_count / ncols;
         let ncols_pad = next_power_of_2(ncols);
+        // Limit block dim to 1024 threads, which is the maximum on modern CUDA gpus.
+        let block_dim = ncols_pad.min(1024);
         let cfg = LaunchConfig {
             grid_dim: (nrows as u32, 1, 1),
-            block_dim: (ncols_pad as u32, 1, 1),
+            block_dim: (block_dim as u32, 1, 1),
             shared_mem_bytes: (ncols_pad * std::mem::size_of::<u32>()) as u32,
         };
         let stream = dev.cuda_stream();
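
For reference, the resulting launch configuration still sizes the shared-memory buffer for the full padded row while capping the thread count at 1024. A minimal host-side sketch of the same arithmetic, written in plain CUDA C++ rather than the crate's cudarc-based Rust (the kernel symbol asort_asc_f32 and the helper launch_argsort are hypothetical names used only for illustration):

// Sketch only: the kernel declaration stands in for the generated argsort
// entry point; the arithmetic mirrors the Rust LaunchConfig above.
#include <cstddef>
#include <cstdint>

extern "C" __global__ void asort_asc_f32(const float *x, uint32_t *dst,
                                         const int ncols, int ncols_pad);

static int next_power_of_2(int n) {
    int p = 1;
    while (p < n) p *= 2;  // smallest power of two >= n
    return p;
}

void launch_argsort(const float *x, uint32_t *dst, int nrows, int ncols) {
    const int ncols_pad = next_power_of_2(ncols);
    // At most 1024 threads per block; the kernel strides over any extra columns.
    const int block_dim = ncols_pad < 1024 ? ncols_pad : 1024;
    // Shared memory still holds one index per padded column.
    const size_t shared_mem_bytes = ncols_pad * sizeof(uint32_t);
    asort_asc_f32<<<nrows, block_dim, shared_mem_bytes>>>(x, dst, ncols, ncols_pad);
}

With the thread cap in place, the practical ceiling on row width presumably moves from the 1024-thread block limit to the per-block shared-memory budget, since the buffer still holds one index per padded column.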

candle-kernels/src/sort.cu

Lines changed: 24 additions & 25 deletions
@@ -14,40 +14,39 @@ static inline __device__ void ggml_cuda_swap(T & a, T & b) {
 template<int order, typename T>
 static __device__ void k_argsort(const T * x, uint32_t * dst, const int ncols, int ncols_pad) {
     // bitonic sort
-    int col = threadIdx.x;
     int row = blockIdx.x;
 
-    if (col >= ncols_pad) {
-        return;
-    }
-
     const T * x_row = x + row * ncols;
     extern __shared__ int dst_row[];
 
-    // initialize indices
-    dst_row[col] = col;
+    // initialize indices - each thread handles multiple elements if ncols_pad > blockDim.x
+    for (int col = threadIdx.x; col < ncols_pad; col += blockDim.x) {
+        dst_row[col] = col;
+    }
 
     __syncthreads();
 
     for (int k = 2; k <= ncols_pad; k *= 2) {
         for (int j = k / 2; j > 0; j /= 2) {
-            int ixj = col ^ j;
-            if (ixj > col) {
-                if ((col & k) == 0) {
-                    if (dst_row[col] >= ncols ||
-                        (dst_row[ixj] < ncols && (order == SORT_ORDER_ASC ?
-                            x_row[dst_row[col]] > x_row[dst_row[ixj]] :
-                            x_row[dst_row[col]] < x_row[dst_row[ixj]]))
-                    ) {
-                        ggml_cuda_swap(dst_row[col], dst_row[ixj]);
-                    }
-                } else {
-                    if (dst_row[ixj] >= ncols ||
-                        (dst_row[col] < ncols && (order == SORT_ORDER_ASC ?
-                            x_row[dst_row[col]] < x_row[dst_row[ixj]] :
-                            x_row[dst_row[col]] > x_row[dst_row[ixj]]))
-                    ) {
-                        ggml_cuda_swap(dst_row[col], dst_row[ixj]);
+            for (int col = threadIdx.x; col < ncols_pad; col += blockDim.x) {
+                int ixj = col ^ j;
+                if (ixj > col) {
+                    if ((col & k) == 0) {
+                        if (dst_row[col] >= ncols ||
+                            (dst_row[ixj] < ncols && (order == SORT_ORDER_ASC ?
+                                x_row[dst_row[col]] > x_row[dst_row[ixj]] :
+                                x_row[dst_row[col]] < x_row[dst_row[ixj]]))
+                        ) {
+                            ggml_cuda_swap(dst_row[col], dst_row[ixj]);
+                        }
+                    } else {
+                        if (dst_row[ixj] >= ncols ||
+                            (dst_row[col] < ncols && (order == SORT_ORDER_ASC ?
+                                x_row[dst_row[col]] < x_row[dst_row[ixj]] :
+                                x_row[dst_row[col]] > x_row[dst_row[ixj]]))
+                        ) {
+                            ggml_cuda_swap(dst_row[col], dst_row[ixj]);
+                        }
                     }
                 }
             }

@@ -56,7 +55,7 @@ static __device__ void k_argsort(const T * x, uint32_t * dst, const int ncols, i
     }
 
     // copy the result to dst without the padding
-    if (col < ncols) {
+    for (int col = threadIdx.x; col < ncols; col += blockDim.x) {
         dst[row * ncols + col] = dst_row[col];
     }
 }
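
The key change is the block-stride loop: instead of one thread per padded column guarded by an early return, each thread walks the columns threadIdx.x, threadIdx.x + blockDim.x, ..., so a block capped at 1024 threads still covers an arbitrarily wide padded row. A minimal, self-contained sketch of just that pattern (the helper name init_indices is made up for illustration):

// Illustration of the block-stride loop used in k_argsort above; the function
// name is hypothetical. Every thread visits the columns threadIdx.x,
// threadIdx.x + blockDim.x, ... so the block as a whole covers ncols_pad
// entries even when ncols_pad > blockDim.x.
static __device__ void init_indices(int *dst_row, int ncols_pad) {
    for (int col = threadIdx.x; col < ncols_pad; col += blockDim.x) {
        dst_row[col] = col;   // identity permutation before sorting
    }
    __syncthreads();          // all indices visible before the sort begins
}

Within a single (k, j) pass the pairs (col, col ^ j) are disjoint and each is touched by exactly one loop iteration (the one where ixj > col), so letting a thread process several pairs sequentially does not appear to change the synchronization requirements of the original one-thread-per-column version.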
