Skip to content

Commit a86f52b

Browse files
CUDA: fix overflow in FA, tune performance (ggml-org#14840)
1 parent b284197 commit a86f52b

File tree

8 files changed: +98 additions, -246 deletions (lines changed)

ggml/src/ggml-cuda/fattn-common.cuh

Lines changed: 11 additions & 34 deletions
Original file line number | Diff line number | Diff line change
@@ -23,33 +23,13 @@ typedef void (* fattn_kernel_t)(
2323
const float m1,
2424
const uint32_t n_head_log2,
2525
const float logit_softcap,
26-
const int ne00,
27-
const int ne01,
28-
const int ne02,
29-
const int ne03,
30-
const int ne10,
31-
const int ne11,
32-
const int ne12,
33-
const int ne13,
34-
const int ne31,
35-
const int ne32,
36-
const int ne33,
37-
const int nb31,
38-
const int nb32,
39-
const int nb33,
40-
const int nb01,
41-
const int nb02,
42-
const int nb03,
43-
const int nb11,
44-
const int nb12,
45-
const int nb13,
46-
const int nb21,
47-
const int nb22,
48-
const int nb23,
49-
const int ne0,
50-
const int ne1,
51-
const int ne2,
52-
const int ne3);
26+
const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
27+
const int32_t nb01, const int32_t nb02, const int32_t nb03,
28+
const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
29+
const int32_t nb11, const int32_t nb12, const int64_t nb13,
30+
const int32_t nb21, const int32_t nb22, const int64_t nb23,
31+
const int32_t ne31, const int32_t ne32, const int32_t ne33,
32+
const int32_t nb31, const int32_t nb32, const int64_t nb33);
5333

5434
typedef half (*vec_dot_KQ_f16_t)(
5535
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds);
@@ -892,14 +872,11 @@ void launch_fattn(
892872
mask ? ((const char *) mask->data) : nullptr,
893873
!stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
894874
scale, max_bias, m0, m1, n_head_log2, logit_softcap,
895-
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
896-
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
897-
mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0, mask ? mask->ne[3] : 0,
898-
mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0, mask ? mask->nb[3] : 0,
899-
Q->nb[1], Q->nb[2], Q->nb[3],
900-
nb11, nb12, nb13,
875+
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], Q->nb[1], Q->nb[2], Q->nb[3],
876+
K->ne[0], K->ne[1], K->ne[2], K->ne[3], nb11, nb12, nb13,
901877
nb21, nb22, nb23,
902-
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
878+
mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0, mask ? mask->ne[3] : 0,
879+
mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0, mask ? mask->nb[3] : 0
903880
);
904881
CUDA_CHECK(cudaGetLastError());
905882

ggml/src/ggml-cuda/fattn-mma-f16.cuh

Lines changed: 16 additions & 39 deletions
Original file line number | Diff line number | Diff line change
@@ -408,7 +408,6 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
408408
const int stride_K,
409409
const int stride_V,
410410
const int stride_mask,
411-
const int jt,
412411
half2 * const __restrict__ tile_Q,
413412
half2 * const __restrict__ tile_K,
414413
half2 * const __restrict__ tile_V,
@@ -455,7 +454,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
455454
cp_async_wait_all();
456455
__syncthreads();
457456
flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, c::nbatch_fa, use_cp_async>
458-
(V_h2 + k_VKQ_0*stride_V, tile_V, nbatch_V2, stride_V);
457+
(V_h2 + int64_t(k_VKQ_0)*stride_V, tile_V, nbatch_V2, stride_V);
459458
} else {
460459
constexpr bool use_cp_async = nstages == 1;
461460
if (ncols2 > 1 || mask_h2) {
@@ -471,7 +470,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
471470
if (nstages <= 1) {
472471
constexpr bool use_cp_async = nstages == 1;
473472
flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
474-
(K_h2 + k_VKQ_0*stride_K + k0_start, tile_K, k0_diff, stride_K);
473+
(K_h2 + int64_t(k_VKQ_0)*stride_K + k0_start, tile_K, k0_diff, stride_K);
475474
if (use_cp_async) {
476475
cp_async_wait_all();
477476
}
@@ -715,7 +714,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
715714
(mask_h2 + (k_VKQ_0 + c::nbatch_fa)/2, tile_mask, stride_mask);
716715
}
717716
flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
718-
(K_h2 + (k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K, nbatch_K2, stride_K);
717+
(K_h2 + int64_t(k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K, nbatch_K2, stride_K);
719718
}
720719
}
721720

@@ -732,7 +731,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
732731
if (nstages <= 1 && i0_start < reusable_cutoff) {
733732
constexpr bool use_cp_async = nstages == 1;
734733
flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, c::nbatch_fa, use_cp_async>
735-
(V_h2 + k_VKQ_0*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V);
734+
(V_h2 + int64_t(k_VKQ_0)*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V);
736735
if (use_cp_async) {
737736
cp_async_wait_all();
738737
}
@@ -771,8 +770,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
771770
GGML_UNUSED(mask_h2); GGML_UNUSED(dstk); GGML_UNUSED(dstk_fixup);
772771
GGML_UNUSED(scale); GGML_UNUSED(slope); GGML_UNUSED(logit_softcap);
773772
GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(stride_K); GGML_UNUSED(stride_V);
774-
GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(tile_K);
775-
GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(tile_K);
773+
GGML_UNUSED(stride_mask); GGML_UNUSED(tile_K);
776774
GGML_UNUSED(tile_V); GGML_UNUSED(tile_mask); GGML_UNUSED(Q_B);
777775
GGML_UNUSED(VKQ_C); GGML_UNUSED(KQ_max); GGML_UNUSED(KQ_rowsum);
778776
GGML_UNUSED(kb0); GGML_UNUSED(tile_Q);
@@ -920,21 +918,21 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
920918
(mask_h2 + kb0_start*c::nbatch_fa/2, tile_mask, stride_mask);
921919
}
922920
flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
923-
(K_h2 + kb0_start*c::nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K);
921+
(K_h2 + int64_t(kb0_start)*c::nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K);
924922
}
925923

926924
// Iterate over ne11 == previous tokens:
927925
for (int kb0 = kb0_start; kb0 < kb0_stop-1; ++kb0) {
928926
constexpr bool last_iter = false;
929927
flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter>
930928
(Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
931-
ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
929+
ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
932930
}
933931
{ // kb0_start is always < kb0_stop so the last iter can be executed unconditionally.
934932
constexpr bool last_iter = true;
935933
flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter>
936934
(Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
937-
ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1);
935+
ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1);
938936
}
939937

940938
// With multi-stage loading there is no __syncthreads at the end of the iter,
@@ -1214,33 +1212,13 @@ static __global__ void flash_attn_ext_f16(
12141212
const float m1,
12151213
const uint32_t n_head_log2,
12161214
const float logit_softcap,
1217-
const int ne00,
1218-
const int ne01,
1219-
const int ne02,
1220-
const int ne03,
1221-
const int ne10,
1222-
const int ne11,
1223-
const int ne12,
1224-
const int ne13,
1225-
const int ne31,
1226-
const int ne32,
1227-
const int ne33,
1228-
const int nb31,
1229-
const int nb32,
1230-
const int nb33,
1231-
const int nb01,
1232-
const int nb02,
1233-
const int nb03,
1234-
const int nb11,
1235-
const int nb12,
1236-
const int nb13,
1237-
const int nb21,
1238-
const int nb22,
1239-
const int nb23,
1240-
const int ne0,
1241-
const int ne1,
1242-
const int ne2,
1243-
const int ne3) {
1215+
const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
1216+
const int32_t nb01, const int32_t nb02, const int32_t nb03,
1217+
const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
1218+
const int32_t nb11, const int32_t nb12, const int64_t nb13,
1219+
const int32_t nb21, const int32_t nb22, const int64_t nb23,
1220+
const int32_t ne31, const int32_t ne32, const int32_t ne33,
1221+
const int32_t nb31, const int32_t nb32, const int64_t nb33) {
12441222
#if defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE)
12451223

12461224
// Skip unused kernel variants for faster compilation:
@@ -1359,8 +1337,7 @@ static __global__ void flash_attn_ext_f16(
13591337
GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
13601338
GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
13611339
GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21);
1362-
GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
1363-
GGML_UNUSED(ne2); GGML_UNUSED(ne3);
1340+
GGML_UNUSED(nb22); GGML_UNUSED(nb23);
13641341
NO_DEVICE_CODE;
13651342
#endif // defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE)
13661343
}

ggml/src/ggml-cuda/fattn-tile-f16.cu

Lines changed: 10 additions & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -21,33 +21,13 @@ static __global__ void flash_attn_tile_ext_f16(
2121
const float m1,
2222
const uint32_t n_head_log2,
2323
const float logit_softcap,
24-
const int ne00,
25-
const int ne01,
26-
const int ne02,
27-
const int ne03,
28-
const int ne10,
29-
const int ne11,
30-
const int ne12,
31-
const int ne13,
32-
const int ne31,
33-
const int ne32,
34-
const int ne33,
35-
const int nb31,
36-
const int nb32,
37-
const int nb33,
38-
const int nb01,
39-
const int nb02,
40-
const int nb03,
41-
const int nb11,
42-
const int nb12,
43-
const int nb13,
44-
const int nb21,
45-
const int nb22,
46-
const int nb23,
47-
const int ne0,
48-
const int ne1,
49-
const int ne2,
50-
const int ne3) {
24+
const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
25+
const int32_t nb01, const int32_t nb02, const int32_t nb03,
26+
const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
27+
const int32_t nb11, const int32_t nb12, const int64_t nb13,
28+
const int32_t nb21, const int32_t nb22, const int64_t nb23,
29+
const int32_t ne31, const int32_t ne32, const int32_t ne33,
30+
const int32_t nb31, const int32_t nb32, const int64_t nb33) {
5131
#if defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
5232

5333
// Skip unused kernel variants for faster compilation:
@@ -127,7 +107,7 @@ static __global__ void flash_attn_tile_ext_f16(
127107
for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) {
128108
const int k_KQ = k_KQ_0 + threadIdx.x;
129109

130-
KV_tmp[i_KQ][k_KQ] = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ];
110+
KV_tmp[i_KQ][k_KQ] = K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ];
131111
}
132112
}
133113

@@ -221,7 +201,7 @@ static __global__ void flash_attn_tile_ext_f16(
221201
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
222202
const int i = i0 + threadIdx.x;
223203

224-
KV_tmp[k][i] = V_h2[(k_VKQ_0 + k)*stride_KV2 + i];
204+
KV_tmp[k][i] = V_h2[int64_t(k_VKQ_0 + k)*stride_KV2 + i];
225205
}
226206
}
227207

@@ -300,8 +280,7 @@ static __global__ void flash_attn_tile_ext_f16(
300280
GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
301281
GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
302282
GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
303-
GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
304-
GGML_UNUSED(ne2); GGML_UNUSED(ne3);
283+
GGML_UNUSED(nb23);
305284
NO_DEVICE_CODE;
306285
#endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
307286
}

ggml/src/ggml-cuda/fattn-tile-f32.cu

Lines changed: 12 additions & 33 deletions
Original file line number | Diff line number | Diff line change
@@ -21,33 +21,13 @@ static __global__ void flash_attn_tile_ext_f32(
2121
const float m1,
2222
const uint32_t n_head_log2,
2323
const float logit_softcap,
24-
const int ne00,
25-
const int ne01,
26-
const int ne02,
27-
const int ne03,
28-
const int ne10,
29-
const int ne11,
30-
const int ne12,
31-
const int ne13,
32-
const int ne31,
33-
const int ne32,
34-
const int ne33,
35-
const int nb31,
36-
const int nb32,
37-
const int nb33,
38-
const int nb01,
39-
const int nb02,
40-
const int nb03,
41-
const int nb11,
42-
const int nb12,
43-
const int nb13,
44-
const int nb21,
45-
const int nb22,
46-
const int nb23,
47-
const int ne0,
48-
const int ne1,
49-
const int ne2,
50-
const int ne3) {
24+
const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
25+
const int32_t nb01, const int32_t nb02, const int32_t nb03,
26+
const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
27+
const int32_t nb11, const int32_t nb12, const int64_t nb13,
28+
const int32_t nb21, const int32_t nb22, const int64_t nb23,
29+
const int32_t ne31, const int32_t ne32, const int32_t ne33,
30+
const int32_t nb31, const int32_t nb32, const int64_t nb33) {
5131
#ifdef FLASH_ATTN_AVAILABLE
5232

5333
// Skip unused kernel variants for faster compilation:
@@ -66,8 +46,7 @@ static __global__ void flash_attn_tile_ext_f32(
6646
GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
6747
GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
6848
GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
69-
GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
70-
GGML_UNUSED(ne2); GGML_UNUSED(ne3);
49+
GGML_UNUSED(nb23);
7150
NO_DEVICE_CODE;
7251
return;
7352
}
@@ -135,7 +114,7 @@ static __global__ void flash_attn_tile_ext_f32(
135114

136115
#pragma unroll
137116
for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 2*WARP_SIZE) {
138-
const half2 tmp = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + threadIdx.x];
117+
const half2 tmp = K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + threadIdx.x];
139118
KV_tmp[i_KQ][k_KQ_0 + 0*WARP_SIZE + threadIdx.x] = __low2float(tmp);
140119
KV_tmp[i_KQ][k_KQ_0 + 1*WARP_SIZE + threadIdx.x] = __high2float(tmp);
141120
}
@@ -231,8 +210,9 @@ static __global__ void flash_attn_tile_ext_f32(
231210
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
232211
const int i = i0 + threadIdx.x;
233212

234-
KV_tmp2[k*(D/2) + i].x = __low2float(V_h2[(k_VKQ_0 + k)*stride_KV2 + i]);
235-
KV_tmp2[k*(D/2) + i].y = __high2float(V_h2[(k_VKQ_0 + k)*stride_KV2 + i]);
213+
const half2 tmp = V_h2[int64_t(k_VKQ_0 + k)*stride_KV2 + i];
214+
KV_tmp2[k*(D/2) + i].x = __low2float(tmp);
215+
KV_tmp2[k*(D/2) + i].y = __high2float(tmp);
236216
}
237217
}
238218

@@ -312,7 +292,6 @@ static __global__ void flash_attn_tile_ext_f32(
312292
GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
313293
GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
314294
GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
315-
GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3);
316295
NO_DEVICE_CODE;
317296
#endif // FLASH_ATTN_AVAILABLE
318297
}

0 commit comments

Comments
 (0)