
Commit 2c51661

CUDA: mul_mat_vec_q for batch sizes > 1 (#5351)
1 parent 8a79c59 commit 2c51661

1 file changed: ggml-cuda.cu (+98 -142 lines)
@@ -5310,41 +5310,50 @@ template <bool need_check> static __global__ void
 #endif // __CUDA_ARCH__ >= CC_VOLTA
 }
 
-template <int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda>
-static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows) {
+template <int ncols_y_template, int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda>
+static __global__ void mul_mat_vec_q(
+    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y_par) {
+
+    const int ncols_y = ncols_y_template != 0 ? ncols_y_template : ncols_y_par;
+
     const int row = blockIdx.x*blockDim.y + threadIdx.y;
 
-    if (row >= nrows) {
+    if (row >= nrows_x) {
         return;
     }
 
-    const int blocks_per_row = ncols / qk;
+    const int blocks_per_row_x = ncols_x / qk;
+    const int blocks_per_col_y = nrows_y / QK8_1;
     const int blocks_per_warp = vdr * WARP_SIZE / qi;
 
     // partial sum for each thread
-    float tmp = 0.0f;
+    float tmp[ncols_y_template != 0 ? ncols_y_template : 8] = {0.0f};
 
     const block_q_t  * x = (const block_q_t  *) vx;
     const block_q8_1 * y = (const block_q8_1 *) vy;
 
-    for (int i = threadIdx.x / (qi/vdr); i < blocks_per_row; i += blocks_per_warp) {
-        const int ibx = row*blocks_per_row + i; // x block index
+    for (int i = threadIdx.x / (qi/vdr); i < blocks_per_row_x; i += blocks_per_warp) {
+        const int ibx = row*blocks_per_row_x + i; // x block index
 
         const int iby = i * (qk/QK8_1); // y block index that aligns with ibx
 
         const int iqs = vdr * (threadIdx.x % (qi/vdr)); // x block quant index when casting the quants to int
 
-        tmp += vec_dot_q_cuda(&x[ibx], &y[iby], iqs);
+#pragma unroll
+        for (int j = 0; j < ncols_y; ++j) {
+            tmp[j] += vec_dot_q_cuda(&x[ibx], &y[j*blocks_per_col_y + iby], iqs);
+        }
     }
 
     // sum up partial sums and write back result
 #pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
-    }
+    for (int j = 0; j < ncols_y; ++j) {
+        tmp[j] = warp_reduce_sum(tmp[j]);
 
-    if (threadIdx.x == 0) {
-        dst[row] = tmp;
+        if (threadIdx.x == 0) {
+            dst[j*nrows_x + row] = tmp[j];
+        }
     }
 }
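The kernel above now handles up to 8 columns of y per launch: each thread keeps one partial sum per y column in the register array tmp, and the unrolled j loop reuses each loaded x block for every column, which is what makes small batches cheaper than ncols_y separate launches. The removed shuffle loop is folded into warp_reduce_sum, a helper defined elsewhere in ggml-cuda.cu; judging from the inline code it replaces, it is presumably equivalent to this sketch:

```cuda
// Butterfly reduction over the 32 lanes of a warp: XOR-ing lane ids with
// 16, 8, 4, 2, 1 pairs every lane with a distinct partner each step, so
// after the loop every lane holds the warp-wide sum of x. This mirrors
// the __shfl_xor_sync loop removed above; the actual helper lives
// elsewhere in ggml-cuda.cu.
static __device__ __forceinline__ float warp_reduce_sum(float x) {
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        x += __shfl_xor_sync(0xffffffff, x, mask, 32);
    }
    return x;
}
```

Note the output layout: dst[j*nrows_x + row] writes result column j as a contiguous run of nrows_x floats, i.e. the destination is stored column by column.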

@@ -6816,121 +6825,56 @@ static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, floa
         <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
 }
 
-static void mul_mat_vec_q4_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % QK4_0 == 0);
-    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(block_num_y, 1, 1);
-    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    mul_mat_vec_q<QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>
-        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
-}
-
-static void mul_mat_vec_q4_1_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % QK4_1 == 0);
-    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(block_num_y, 1, 1);
-    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    mul_mat_vec_q<QK4_0, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>
-        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
-}
-
-static void mul_mat_vec_q5_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % QK5_0 == 0);
-    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(block_num_y, 1, 1);
-    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    mul_mat_vec_q<QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>
-        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
-}
-
-static void mul_mat_vec_q5_1_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % QK5_1 == 0);
-    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(block_num_y, 1, 1);
-    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    mul_mat_vec_q<QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>
-        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
-}
-
-static void mul_mat_vec_q8_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % QK8_0 == 0);
-    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(block_num_y, 1, 1);
-    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    mul_mat_vec_q<QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>
-        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
-}
-
-static void mul_mat_vec_q2_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % QK_K == 0);
-    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(block_num_y, 1, 1);
-    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    mul_mat_vec_q<QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>
-        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
-}
-
-static void mul_mat_vec_q3_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % QK_K == 0);
-    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(block_num_y, 1, 1);
-    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    mul_mat_vec_q<QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>
-        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
-}
-
-static void mul_mat_vec_q4_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % QK_K == 0);
-    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(block_num_y, 1, 1);
-    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    mul_mat_vec_q<QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>
-        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
-}
-
-static void mul_mat_vec_q5_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % QK_K == 0);
-    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(block_num_y, 1, 1);
-    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    mul_mat_vec_q<QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>
-        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
-}
-
-static void mul_mat_vec_q6_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % QK_K == 0);
-    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(block_num_y, 1, 1);
-    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    mul_mat_vec_q<QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>
-        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
-}
-
-static void mul_mat_vec_iq2_xxs_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % QK_K == 0);
-    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(block_num_y, 1, 1);
-    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    mul_mat_vec_q<QK_K, QI2_XXS, block_iq2_xxs, 1, vec_dot_iq2_xxs_q8_1>
-        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
-}
+template <int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot>
+static void mul_mat_vec_q_cuda(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, cudaStream_t stream) {
 
-static void mul_mat_vec_iq2_xs_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % QK_K == 0);
-    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(block_num_y, 1, 1);
-    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    mul_mat_vec_q<QK_K, QI2_XS, block_iq2_xs, 1, vec_dot_iq2_xs_q8_1>
-        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
-}
+    GGML_ASSERT(ncols_x % qk == 0);
+    GGML_ASSERT(ncols_y <= 8);
 
-static void mul_mat_vec_iq3_xxs_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % QK_K == 0);
-    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+    const int block_num_y = (nrows_x + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
     const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    mul_mat_vec_q<QK_K, QI3_XXS, block_iq3_xxs, 1, vec_dot_iq3_xxs_q8_1>
-        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
+    switch (ncols_y) {
+        case 1:
+            mul_mat_vec_q<1, qk, qi, block_q_t, vdr, vec_dot>
+                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
+            break;
+        case 2:
+            mul_mat_vec_q<2, qk, qi, block_q_t, vdr, vec_dot>
+                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
+            break;
+        case 3:
+            mul_mat_vec_q<3, qk, qi, block_q_t, vdr, vec_dot>
+                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
+            break;
+        case 4:
+            mul_mat_vec_q<4, qk, qi, block_q_t, vdr, vec_dot>
+                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
+            break;
+        case 5:
+            mul_mat_vec_q<5, qk, qi, block_q_t, vdr, vec_dot>
+                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
+            break;
+        case 6:
+            mul_mat_vec_q<6, qk, qi, block_q_t, vdr, vec_dot>
+                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
+            break;
+        case 7:
+            mul_mat_vec_q<7, qk, qi, block_q_t, vdr, vec_dot>
+                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
+            break;
+        case 8:
+            mul_mat_vec_q<8, qk, qi, block_q_t, vdr, vec_dot>
+                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
+            break;
+        default:
+            GGML_ASSERT(false);
+            // mul_mat_vec_q<0, qk, qi, block_q_t, vdr, vec_dot>
+            //     <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
+            break;
+    }
 }
 
 static void ggml_mul_mat_q4_0_q8_1_cuda(
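The new mul_mat_vec_q_cuda launcher replaces thirteen near-identical per-type wrappers. Its switch promotes the runtime batch size ncols_y to the compile-time template parameter ncols_y_template, so the kernel's tmp array has a static size (and can live in registers) and the per-column loops unroll fully; the commented-out default case is the disabled runtime fallback (template argument 0, which makes the kernel read ncols_y_par instead). A minimal sketch of this runtime-to-compile-time dispatch idiom, with illustrative names that are not from the file:

```cuda
#include <cuda_runtime.h>

// Kernel sized by a template parameter: the accumulator array has a
// compile-time length, so it can be held in registers, and the loop
// bound is a constant the compiler can fully unroll.
template <int ncols>
static __global__ void batched_kernel(float * dst) {
    float acc[ncols] = {0.0f};
#pragma unroll
    for (int j = 0; j < ncols; ++j) {
        acc[j] += (float) j;
    }
    dst[threadIdx.x] = acc[ncols - 1];
}

// Launcher: map the runtime value onto one instantiation per supported
// size, exactly as mul_mat_vec_q_cuda does for ncols_y = 1..8.
static void launch_batched(float * dst, const int ncols, cudaStream_t stream) {
    switch (ncols) {
        case 1: batched_kernel<1><<<1, 32, 0, stream>>>(dst); break;
        case 2: batched_kernel<2><<<1, 32, 0, stream>>>(dst); break;
        // ... one case per supported value ...
        default: break; // unsupported size: assert or take a runtime fallback
    }
}
```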
@@ -8578,50 +8522,61 @@ static void ggml_cuda_op_mul_mat_vec_q(
     const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
     const int64_t src1_padded_row_size, cudaStream_t stream) {
 
-    GGML_ASSERT(ggml_nrows(src1) == 1);
-
     const int64_t ne00 = src0->ne[0];
     const int64_t row_diff = row_high - row_low;
 
     switch (src0->type) {
         case GGML_TYPE_Q4_0:
-            mul_mat_vec_q4_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+            mul_mat_vec_q_cuda<QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>
+                (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
             break;
         case GGML_TYPE_Q4_1:
-            mul_mat_vec_q4_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+            mul_mat_vec_q_cuda<QK4_1, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>
+                (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
             break;
         case GGML_TYPE_Q5_0:
-            mul_mat_vec_q5_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+            mul_mat_vec_q_cuda<QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>
+                (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
             break;
         case GGML_TYPE_Q5_1:
-            mul_mat_vec_q5_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+            mul_mat_vec_q_cuda<QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>
+                (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
             break;
         case GGML_TYPE_Q8_0:
-            mul_mat_vec_q8_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+            mul_mat_vec_q_cuda<QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>
+                (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
             break;
         case GGML_TYPE_Q2_K:
-            mul_mat_vec_q2_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+            mul_mat_vec_q_cuda<QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>
+                (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
             break;
         case GGML_TYPE_Q3_K:
-            mul_mat_vec_q3_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+            mul_mat_vec_q_cuda<QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>
+                (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
             break;
         case GGML_TYPE_Q4_K:
-            mul_mat_vec_q4_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+            mul_mat_vec_q_cuda<QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>
+                (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
             break;
         case GGML_TYPE_Q5_K:
-            mul_mat_vec_q5_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+            mul_mat_vec_q_cuda<QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>
+                (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
             break;
         case GGML_TYPE_Q6_K:
-            mul_mat_vec_q6_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+            mul_mat_vec_q_cuda<QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>
+                (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
             break;
         case GGML_TYPE_IQ2_XXS:
-            mul_mat_vec_iq2_xxs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+            mul_mat_vec_q_cuda<QK_K, QI2_XXS, block_iq2_xxs, 1, vec_dot_iq2_xxs_q8_1>
+                (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
             break;
         case GGML_TYPE_IQ2_XS:
-            mul_mat_vec_iq2_xs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+            mul_mat_vec_q_cuda<QK_K, QI2_XS, block_iq2_xs, 1, vec_dot_iq2_xs_q8_1>
+                (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
             break;
         case GGML_TYPE_IQ3_XXS:
-            mul_mat_vec_iq3_xxs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+            mul_mat_vec_q_cuda<QK_K, QI3_XXS, block_iq3_xxs, 1, vec_dot_iq3_xxs_q8_1>
+                (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
             break;
         default:
             GGML_ASSERT(false);
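Each case above now instantiates the shared launcher directly and passes the geometry through: src1_padded_row_size becomes nrows_y (the q8_1-quantized y columns are padded, so whole blocks divide evenly) and src1_ncols becomes ncols_y. A worked example of the block arithmetic the kernel relies on, with illustrative sizes rather than values from a specific model:

```cuda
// Assuming ggml's block sizes QK4_0 == QK8_1 == 32:
const int ncols_x              = 4096; // row length of quantized x (ne00)
const int src1_padded_row_size = 4096; // y column length, padded to a multiple of QK8_1
const int src1_ncols           = 4;    // batch size, must be <= 8

const int blocks_per_row_x = ncols_x / 32;              // 128 q4_0 blocks per x row
const int blocks_per_col_y = src1_padded_row_size / 32; // 128 q8_1 blocks per y column
// Column j of quantized y starts at block j * blocks_per_col_y, which is
// what the kernel's &y[j*blocks_per_col_y + iby] indexing depends on.
```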
@@ -9945,17 +9900,18 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
 #ifdef GGML_CUDA_FORCE_DMMV
         const bool use_mul_mat_vec_q = false;
 #else
-        const bool use_mul_mat_vec_q = min_compute_capability >= MIN_CC_DP4A && ggml_is_quantized(src0->type) && ggml_nrows(src1) == 1;
+        const bool use_mul_mat_vec_q = min_compute_capability >= MIN_CC_DP4A && ggml_is_quantized(src0->type);
 #endif // GGML_CUDA_FORCE_DMMV
 
         if (use_mul_mat_vec_q) {
-            // NOTE: this kernel does not support ggml_nrows(src1) > 1
             ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, true);
         } else {
             ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, false);
         }
     } else {
-        if (use_mul_mat_q) {
+        if (src1->ne[1] <= 8 && min_compute_capability >= MIN_CC_DP4A && ggml_is_quantized(src0->type)) {
+            ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, true);
+        } else if (use_mul_mat_q) {
             ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_q, true);
         } else {
             ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false);
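The final hunk lifts the batch-size-1 restriction in two places: use_mul_mat_vec_q no longer checks ggml_nrows(src1) == 1, and the batched path gains a branch that routes batches of at most 8 columns to the mul_mat_vec_q kernel ahead of the mul_mat_q / cuBLAS choice, matching the GGML_ASSERT(ncols_y <= 8) in the launcher. The condition, restated as a hedged predicate (not a function that exists in the file):

```cuda
// Routing for the batched case after this commit: small quantized batches
// take the MMVQ kernel; everything else falls through to MMQ or cuBLAS.
// MIN_CC_DP4A is ggml's minimum compute capability for the dp4a
// instruction the quantized dot products use.
static bool use_mmvq_for_batch(const int64_t batch_size, const int compute_capability,
                               const bool src0_is_quantized) {
    return batch_size <= 8 && compute_capability >= MIN_CC_DP4A && src0_is_quantized;
}
```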
