@@ -408,7 +408,6 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         const int stride_K,
         const int stride_V,
         const int stride_mask,
-        const int jt,
         half2        * const __restrict__ tile_Q,
         half2        * const __restrict__ tile_K,
         half2        * const __restrict__ tile_V,
@@ -455,7 +454,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         cp_async_wait_all();
         __syncthreads();
         flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, c::nbatch_fa, use_cp_async>
-            (V_h2 + k_VKQ_0*stride_V, tile_V, nbatch_V2, stride_V);
+            (V_h2 + int64_t(k_VKQ_0)*stride_V, tile_V, nbatch_V2, stride_V);
     } else {
         constexpr bool use_cp_async = nstages == 1;
         if (ncols2 > 1 || mask_h2) {
@@ -471,7 +470,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         if (nstages <= 1) {
             constexpr bool use_cp_async = nstages == 1;
             flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
-                (K_h2 + k_VKQ_0*stride_K + k0_start, tile_K, k0_diff, stride_K);
+                (K_h2 + int64_t(k_VKQ_0)*stride_K + k0_start, tile_K, k0_diff, stride_K);
             if (use_cp_async) {
                 cp_async_wait_all();
             }
@@ -715,7 +714,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
                     (mask_h2 + (k_VKQ_0 + c::nbatch_fa)/2, tile_mask, stride_mask);
             }
             flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
-                (K_h2 + (k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K, nbatch_K2, stride_K);
+                (K_h2 + int64_t(k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K, nbatch_K2, stride_K);
         }
     }
 
@@ -732,7 +731,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         if (nstages <= 1 && i0_start < reusable_cutoff) {
             constexpr bool use_cp_async = nstages == 1;
             flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, c::nbatch_fa, use_cp_async>
-                (V_h2 + k_VKQ_0*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V);
+                (V_h2 + int64_t(k_VKQ_0)*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V);
             if (use_cp_async) {
                 cp_async_wait_all();
             }
@@ -771,8 +770,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
     GGML_UNUSED(mask_h2); GGML_UNUSED(dstk); GGML_UNUSED(dstk_fixup);
     GGML_UNUSED(scale); GGML_UNUSED(slope); GGML_UNUSED(logit_softcap);
     GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(stride_K); GGML_UNUSED(stride_V);
-    GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(tile_K);
-    GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(tile_K);
+    GGML_UNUSED(stride_mask); GGML_UNUSED(tile_K);
     GGML_UNUSED(tile_V); GGML_UNUSED(tile_mask); GGML_UNUSED(Q_B);
     GGML_UNUSED(VKQ_C); GGML_UNUSED(KQ_max); GGML_UNUSED(KQ_rowsum);
     GGML_UNUSED(kb0); GGML_UNUSED(tile_Q);
@@ -920,21 +918,21 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
                 (mask_h2 + kb0_start*c::nbatch_fa/2, tile_mask, stride_mask);
         }
         flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
-            (K_h2 + kb0_start*c::nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K);
+            (K_h2 + int64_t(kb0_start)*c::nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K);
     }
 
     // Iterate over ne11 == previous tokens:
     for (int kb0 = kb0_start; kb0 < kb0_stop-1; ++kb0) {
         constexpr bool last_iter = false;
         flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter>
             (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
-             ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
+             ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
     }
     { // kb0_start is always < kb0_stop so the last iter can be executed unconditionally.
         constexpr bool last_iter = true;
         flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter>
             (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
-             ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1);
+             ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1);
     }
 
     // With multi-stage loading there is no __syncthreads at the end of the iter,
@@ -1214,33 +1212,13 @@ static __global__ void flash_attn_ext_f16(
         const float m1,
         const uint32_t n_head_log2,
         const float logit_softcap,
-        const int ne00,
-        const int ne01,
-        const int ne02,
-        const int ne03,
-        const int ne10,
-        const int ne11,
-        const int ne12,
-        const int ne13,
-        const int ne31,
-        const int ne32,
-        const int ne33,
-        const int nb31,
-        const int nb32,
-        const int nb33,
-        const int nb01,
-        const int nb02,
-        const int nb03,
-        const int nb11,
-        const int nb12,
-        const int nb13,
-        const int nb21,
-        const int nb22,
-        const int nb23,
-        const int ne0,
-        const int ne1,
-        const int ne2,
-        const int ne3) {
+        const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
+        const int32_t nb01, const int32_t nb02, const int32_t nb03,
+        const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
+        const int32_t nb11, const int32_t nb12, const int64_t nb13,
+        const int32_t nb21, const int32_t nb22, const int64_t nb23,
+        const int32_t ne31, const int32_t ne32, const int32_t ne33,
+        const int32_t nb31, const int32_t nb32, const int64_t nb33) {
 #if defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE)
 
     // Skip unused kernel variants for faster compilation:
@@ -1359,8 +1337,7 @@ static __global__ void flash_attn_ext_f16(
     GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
     GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
     GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21);
-    GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
-    GGML_UNUSED(ne2); GGML_UNUSED(ne3);
+    GGML_UNUSED(nb22); GGML_UNUSED(nb23);
     NO_DEVICE_CODE;
 #endif // defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE)
 }
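
Side note on the pattern applied throughout the hunks above: the row index (k_VKQ_0 or kb0_start) is widened to int64_t before it is multiplied by the K/V stride, presumably so that the element offset into K_h2/V_h2 is computed in 64-bit arithmetic and cannot overflow a 32-bit int for very long contexts. A minimal standalone sketch of the difference, using hypothetical sizes that are not taken from the kernel:

// Sketch only: illustrates why int64_t(k_VKQ_0)*stride_V is preferred over
// k_VKQ_0*stride_V. The values below are hypothetical, not kernel parameters.
#include <cstdint>
#include <cstdio>

int main() {
    const int k_VKQ_0  = 131072; // hypothetical row index into the KV cache (2^17)
    const int stride_V = 65536;  // hypothetical row stride in half2 elements (2^16)

    // int*int is evaluated in 32-bit arithmetic; 2^17 * 2^16 = 2^33 does not fit,
    // so the plain product k_VKQ_0*stride_V would overflow. Doing the same
    // multiplication in uint32_t shows the wrapped value (here it wraps to 0):
    const uint32_t wrapped = (uint32_t) k_VKQ_0 * (uint32_t) stride_V;

    // Widening one operand first promotes the whole product to 64 bits and keeps
    // the pointer offset exact, matching the int64_t(k_VKQ_0)*stride_V pattern:
    const int64_t exact = int64_t(k_VKQ_0) * stride_V;

    printf("32-bit wrapped product: %u\n",   (unsigned) wrapped);
    printf("64-bit exact product:   %lld\n", (long long) exact);
    return 0;
}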