Skip to content

Commit 9d8d97c

Browse files
committed
restore local workgroup size adjustments for large inputs
1 parent 8b8f705 commit 9d8d97c

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

ggml/src/ggml-sycl/im2col.cpp

+6-4
Original file line numberDiff line numberDiff line change
@@ -58,11 +58,13 @@ static void im2col_sycl_internal(const float * x, T * dst, int64_t IW, int64_t I
5858
int64_t KH, int64_t IC, int64_t batch, int64_t batch_offset, int64_t offset_delta,
5959
int s0, int s1, int p0, int p1, int d0, int d1, queue_ptr stream) {
6060
const int64_t parallel_elements = OW * KW * KH;
61-
const int64_t block_size_x = SYCL_IM2COL_BLOCK_SIZE;
62-
const int64_t num_groups_x = (parallel_elements + block_size_x - 1) / block_size_x;
61+
const int64_t num_blocks = (parallel_elements + SYCL_IM2COL_BLOCK_SIZE - 1) / SYCL_IM2COL_BLOCK_SIZE;
6362

64-
sycl::range<3> block_nums(batch * IC, OH, num_groups_x);
65-
sycl::range<3> local_range(1, 1, block_size_x);
63+
// shrink the local work-group size so that the total global range does not exceed the max int
64+
int64_t local_size = downsample_sycl_global_range(batch * IC * OH * num_blocks, SYCL_IM2COL_BLOCK_SIZE);
65+
66+
sycl::range<3> block_nums(batch * IC, OH, num_blocks);
67+
sycl::range<3> local_range(1, 1, local_size);
6668

6769
const int64_t CHW = IC * KH * KW;
6870

0 commit comments

Comments
 (0)