@@ -5310,41 +5310,50 @@ template <bool need_check> static __global__ void
 #endif // __CUDA_ARCH__ >= CC_VOLTA
 }
 
-template <int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda>
-static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows) {
+template <int ncols_y_template, int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda>
+static __global__ void mul_mat_vec_q(
+    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y_par) {
+
+    const int ncols_y = ncols_y_template != 0 ? ncols_y_template : ncols_y_par;
+
     const int row = blockIdx.x*blockDim.y + threadIdx.y;
 
-    if (row >= nrows) {
+    if (row >= nrows_x) {
         return;
     }
 
-    const int blocks_per_row = ncols / qk;
+    const int blocks_per_row_x = ncols_x / qk;
+    const int blocks_per_col_y = nrows_y / QK8_1;
     const int blocks_per_warp = vdr * WARP_SIZE / qi;
 
     // partial sum for each thread
-    float tmp = 0.0f;
+    float tmp[ncols_y_template != 0 ? ncols_y_template : 8] = {0.0f};
 
     const block_q_t  * x = (const block_q_t  *) vx;
     const block_q8_1 * y = (const block_q8_1 *) vy;
 
-    for (int i = threadIdx.x / (qi/vdr); i < blocks_per_row; i += blocks_per_warp) {
-        const int ibx = row*blocks_per_row + i; // x block index
+    for (int i = threadIdx.x / (qi/vdr); i < blocks_per_row_x; i += blocks_per_warp) {
+        const int ibx = row*blocks_per_row_x + i; // x block index
 
         const int iby = i * (qk/QK8_1); // y block index that aligns with ibx
 
         const int iqs = vdr * (threadIdx.x % (qi/vdr)); // x block quant index when casting the quants to int
 
-        tmp += vec_dot_q_cuda(&x[ibx], &y[iby], iqs);
+#pragma unroll
+        for (int j = 0; j < ncols_y; ++j) {
+            tmp[j] += vec_dot_q_cuda(&x[ibx], &y[j*blocks_per_col_y + iby], iqs);
+        }
     }
 
     // sum up partial sums and write back result
 #pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
-    }
+    for (int j = 0; j < ncols_y; ++j) {
+        tmp[j] = warp_reduce_sum(tmp[j]);
 
-    if (threadIdx.x == 0) {
-        dst[row] = tmp;
+        if (threadIdx.x == 0) {
+            dst[j*nrows_x + row] = tmp[j];
+        }
     }
 }
 
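The rewritten reduction calls warp_reduce_sum once per y column instead of open-coding the shuffle loop. That helper is defined elsewhere in ggml-cuda.cu and is not part of this diff; judging from the removed lines, it presumably wraps the same warp-level butterfly reduction, roughly:

// Sketch of the helper the new code relies on (assumed, not shown in the diff).
// It performs the same butterfly reduction the removed __shfl_xor_sync loop did
// inline: after the loop, every lane in the warp holds the full warp sum.
static __device__ __forceinline__ float warp_reduce_sum(float x) {
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        x += __shfl_xor_sync(0xffffffff, x, mask, 32);
    }
    return x;
}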
@@ -6816,121 +6825,56 @@ static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, floa
         <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
 }
 
-static void mul_mat_vec_q4_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % QK4_0 == 0);
-    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(block_num_y, 1, 1);
-    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    mul_mat_vec_q<QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>
-        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
-}
-
-static void mul_mat_vec_q4_1_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % QK4_1 == 0);
-    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(block_num_y, 1, 1);
-    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    mul_mat_vec_q<QK4_0, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>
-        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
-}
-
-static void mul_mat_vec_q5_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % QK5_0 == 0);
-    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(block_num_y, 1, 1);
-    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    mul_mat_vec_q<QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>
-        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
-}
-
-static void mul_mat_vec_q5_1_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % QK5_1 == 0);
-    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(block_num_y, 1, 1);
-    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    mul_mat_vec_q<QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>
-        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
-}
-
-static void mul_mat_vec_q8_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % QK8_0 == 0);
-    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(block_num_y, 1, 1);
-    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    mul_mat_vec_q<QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>
-        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
-}
-
-static void mul_mat_vec_q2_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % QK_K == 0);
-    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(block_num_y, 1, 1);
-    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    mul_mat_vec_q<QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>
-        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
-}
-
-static void mul_mat_vec_q3_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % QK_K == 0);
-    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(block_num_y, 1, 1);
-    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    mul_mat_vec_q<QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>
-        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
-}
-
-static void mul_mat_vec_q4_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % QK_K == 0);
-    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(block_num_y, 1, 1);
-    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    mul_mat_vec_q<QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>
-        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
-}
-
-static void mul_mat_vec_q5_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % QK_K == 0);
-    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(block_num_y, 1, 1);
-    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    mul_mat_vec_q<QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>
-        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
-}
-
-static void mul_mat_vec_q6_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % QK_K == 0);
-    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(block_num_y, 1, 1);
-    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    mul_mat_vec_q<QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>
-        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
-}
-
-static void mul_mat_vec_iq2_xxs_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % QK_K == 0);
-    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(block_num_y, 1, 1);
-    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    mul_mat_vec_q<QK_K, QI2_XXS, block_iq2_xxs, 1, vec_dot_iq2_xxs_q8_1>
-        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
-}
+template <int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot>
+static void mul_mat_vec_q_cuda(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, cudaStream_t stream) {
 
-static void mul_mat_vec_iq2_xs_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % QK_K == 0);
-    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(block_num_y, 1, 1);
-    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    mul_mat_vec_q<QK_K, QI2_XS, block_iq2_xs, 1, vec_dot_iq2_xs_q8_1>
-        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
-}
+    GGML_ASSERT(ncols_x % qk == 0);
+    GGML_ASSERT(ncols_y <= 8);
 
-static void mul_mat_vec_iq3_xxs_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % QK_K == 0);
-    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+    const int block_num_y = (nrows_x + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
     const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    mul_mat_vec_q<QK_K, QI3_XXS, block_iq3_xxs, 1, vec_dot_iq3_xxs_q8_1>
-        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
+    switch (ncols_y) {
+        case 1:
+            mul_mat_vec_q<1, qk, qi, block_q_t, vdr, vec_dot>
+                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
+            break;
+        case 2:
+            mul_mat_vec_q<2, qk, qi, block_q_t, vdr, vec_dot>
+                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
+            break;
+        case 3:
+            mul_mat_vec_q<3, qk, qi, block_q_t, vdr, vec_dot>
+                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
+            break;
+        case 4:
+            mul_mat_vec_q<4, qk, qi, block_q_t, vdr, vec_dot>
+                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
+            break;
+        case 5:
+            mul_mat_vec_q<5, qk, qi, block_q_t, vdr, vec_dot>
+                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
+            break;
+        case 6:
+            mul_mat_vec_q<6, qk, qi, block_q_t, vdr, vec_dot>
+                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
+            break;
+        case 7:
+            mul_mat_vec_q<7, qk, qi, block_q_t, vdr, vec_dot>
+                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
+            break;
+        case 8:
+            mul_mat_vec_q<8, qk, qi, block_q_t, vdr, vec_dot>
+                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
+            break;
+        default:
+            GGML_ASSERT(false);
+            // mul_mat_vec_q<0, qk, qi, block_q_t, vdr, vec_dot>
+            //     <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
+            break;
+    }
 }
 
 static void ggml_mul_mat_q4_0_q8_1_cuda(
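The switch above exists to turn the runtime batch size into the compile-time ncols_y_template argument, so the kernel can size its per-thread tmp array in registers and fully unroll the per-column loops. A self-contained toy of the same dispatch pattern (hypothetical names, simplified kernel, not from the diff):

// Toy illustration of runtime-to-compile-time dispatch. The template argument
// lets the compiler unroll the column loop and keep per-column state in registers.
template <int ncols_template>
static __global__ void toy_scale_cols(const float * x, const float * y, float * dst, const int nrows) {
    const int row = blockIdx.x*blockDim.x + threadIdx.x;
    if (row >= nrows) {
        return;
    }
#pragma unroll
    for (int j = 0; j < ncols_template; ++j) {
        dst[j*nrows + row] = x[row] * y[j]; // same column-major output layout as mul_mat_vec_q
    }
}

static void toy_scale_cols_cuda(const float * x, const float * y, float * dst,
                                const int nrows, const int ncols, cudaStream_t stream) {
    const dim3 block_nums((nrows + 255) / 256, 1, 1);
    const dim3 block_dims(256, 1, 1);
    switch (ncols) { // one kernel instantiation per supported column count
        case 1: toy_scale_cols<1><<<block_nums, block_dims, 0, stream>>>(x, y, dst, nrows); break;
        case 2: toy_scale_cols<2><<<block_nums, block_dims, 0, stream>>>(x, y, dst, nrows); break;
        default: /* unsupported column count */ break;
    }
}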
@@ -8578,50 +8522,61 @@ static void ggml_cuda_op_mul_mat_vec_q(
     const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
     const int64_t src1_padded_row_size, cudaStream_t stream) {
 
-    GGML_ASSERT(ggml_nrows(src1) == 1);
-
     const int64_t ne00 = src0->ne[0];
     const int64_t row_diff = row_high - row_low;
 
     switch (src0->type) {
         case GGML_TYPE_Q4_0:
-            mul_mat_vec_q4_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+            mul_mat_vec_q_cuda<QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>
+                (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
             break;
         case GGML_TYPE_Q4_1:
-            mul_mat_vec_q4_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+            mul_mat_vec_q_cuda<QK4_1, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>
+                (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
             break;
         case GGML_TYPE_Q5_0:
-            mul_mat_vec_q5_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+            mul_mat_vec_q_cuda<QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>
+                (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
             break;
         case GGML_TYPE_Q5_1:
-            mul_mat_vec_q5_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+            mul_mat_vec_q_cuda<QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>
+                (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
             break;
         case GGML_TYPE_Q8_0:
-            mul_mat_vec_q8_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+            mul_mat_vec_q_cuda<QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>
+                (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
             break;
         case GGML_TYPE_Q2_K:
-            mul_mat_vec_q2_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+            mul_mat_vec_q_cuda<QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>
+                (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
             break;
         case GGML_TYPE_Q3_K:
-            mul_mat_vec_q3_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+            mul_mat_vec_q_cuda<QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>
+                (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
             break;
         case GGML_TYPE_Q4_K:
-            mul_mat_vec_q4_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+            mul_mat_vec_q_cuda<QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>
+                (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
             break;
         case GGML_TYPE_Q5_K:
-            mul_mat_vec_q5_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+            mul_mat_vec_q_cuda<QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>
+                (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
             break;
         case GGML_TYPE_Q6_K:
-            mul_mat_vec_q6_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+            mul_mat_vec_q_cuda<QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>
+                (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
             break;
         case GGML_TYPE_IQ2_XXS:
-            mul_mat_vec_iq2_xxs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+            mul_mat_vec_q_cuda<QK_K, QI2_XXS, block_iq2_xxs, 1, vec_dot_iq2_xxs_q8_1>
+                (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
             break;
         case GGML_TYPE_IQ2_XS:
-            mul_mat_vec_iq2_xs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+            mul_mat_vec_q_cuda<QK_K, QI2_XS, block_iq2_xs, 1, vec_dot_iq2_xs_q8_1>
+                (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
            break;
        case GGML_TYPE_IQ3_XXS:
-            mul_mat_vec_iq3_xxs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+            mul_mat_vec_q_cuda<QK_K, QI3_XXS, block_iq3_xxs, 1, vec_dot_iq3_xxs_q8_1>
+                (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
             break;
         default:
             GGML_ASSERT(false);
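Note the argument mapping on the new calls: ne00 and row_diff become ncols_x and nrows_x, while src1_padded_row_size is what reaches the kernel as nrows_y. Since the q8_1 copy of src1 is padded (hence the parameter name), each y column occupies nrows_y/QK8_1 block_q8_1 entries; for example, with src1_padded_row_size = 4096 and QK8_1 = 32, column j starts at y + j*128.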
@@ -9945,17 +9900,18 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
 #ifdef GGML_CUDA_FORCE_DMMV
         const bool use_mul_mat_vec_q = false;
 #else
-        const bool use_mul_mat_vec_q = min_compute_capability >= MIN_CC_DP4A && ggml_is_quantized(src0->type) && ggml_nrows(src1) == 1;
+        const bool use_mul_mat_vec_q = min_compute_capability >= MIN_CC_DP4A && ggml_is_quantized(src0->type);
 #endif // GGML_CUDA_FORCE_DMMV
 
         if (use_mul_mat_vec_q) {
-            // NOTE: this kernel does not support ggml_nrows(src1) > 1
             ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, true);
         } else {
             ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, false);
         }
     } else {
-        if (use_mul_mat_q) {
+        if (src1->ne[1] <= 8 && min_compute_capability >= MIN_CC_DP4A && ggml_is_quantized(src0->type)) {
+            ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, true);
+        } else if (use_mul_mat_q) {
             ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_q, true);
         } else {
             ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false);
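The new src1->ne[1] <= 8 branch mirrors the GGML_ASSERT(ncols_y <= 8) in the launcher and the fallback size of 8 for the kernel's tmp array: each thread keeps one float accumulator per y column, so batch sizes above 8 would presumably cost too many registers per thread and are left to the mul_mat_q and cuBLAS paths.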