
Commit f649f69

This PR fixes the causal mask when seq_len_q != seq_len_kv (#314)

This PR adopts the causal-mask convention used by flash-attention (https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_attn_interface.py#L1087), where the mask is aligned to the bottom-right corner of the score matrix, for the case seq_len_q != seq_len_kv.

Co-authored-by: Muhammad Tanvir <[email protected]>
1 parent f36600c commit f649f69

3 files changed: +32 -18 lines
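The adopted convention anchors the causal diagonal at the bottom-right corner of the seq_len_qo x seq_len_kv score matrix, so the last query row always sees every key. Below is a minimal standalone sketch of that predicate on a toy 4x6 problem; the helper name and sizes are illustrative and are not part of the changed files.

#include <algorithm>
#include <cstdio>

// Bottom-right-aligned causal mask predicate (illustrative helper, not PR code).
// A score at (row, col) is masked when the shifted column index exceeds the
// shifted row index.
bool is_masked(int row, int col, int seq_len_qo, int seq_len_kv) {
  int offset = std::min(seq_len_qo, seq_len_kv);
  int discard_seq_coord = seq_len_qo - offset;  // query rows that see no keys at all
  int full_tile_offset  = seq_len_kv - offset;  // shift that aligns the diagonal bottom-right
  return (col - full_tile_offset) > (row - discard_seq_coord);
}

int main() {
  // Toy example: 4 queries, 6 keys. 'x' = kept, '.' = masked.
  for (int row = 0; row < 4; row++) {
    for (int col = 0; col < 6; col++)
      std::printf("%c", is_masked(row, col, 4, 6) ? '.' : 'x');
    std::printf("\n");
  }
}

With 4 queries and 6 keys this prints a band that ends at the bottom-right corner (row 0 keeps 3 keys, row 3 keeps all 6); with more queries than keys, the leading rows are fully masked, which is what the kernel and host reference below handle explicitly.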

applications/flash_attention_v2/kernel/xe_flash_attn_gemm.hpp

Lines changed: 13 additions & 10 deletions
@@ -251,6 +251,17 @@ class GemmUniversalAttention {
         continue;
       }

+      auto offset = cute::min(seq_len_qo, seq_len_kv); //(2048, 1024)
+      auto discard_seq_coord = seq_len_qo - offset; //1024
+      auto full_tile_offset = seq_len_kv - offset; //0
+      const int seq_coord = cute::min(seq_len_qo, blk_m_coord * QK_BLK_M + (sub_group_id / PV_ATOM_N) * QK_SG_M);
+
+      const int seq_len = CausalMask ? full_tile_offset + cute::min(seq_len_kv, seq_coord - discard_seq_coord) + QK_SG_M : seq_len_kv;
+      const int nblock_limit = cute::ceil_div(seq_len, QK_BLK_N);
+      if (CausalMask && seq_coord < discard_seq_coord) {
+        continue;
+      }
+
       Tensor mQ_mkl = cute::get_pvc_tensor(make_shape(seq_len_qo, head_size_qk, (is_var_len ? 1 : batch) * num_heads_q)); //(m,k,l)
       Tensor mK_nkl = cute::get_pvc_tensor(make_shape(seq_len_kv, head_size_qk, (is_var_len ? 1 : batch) * num_head_kv)); //(n,k,l)
       Tensor mV_nkl = cute::get_pvc_tensor(make_shape(head_size_vo, seq_len_kv, (is_var_len ? 1 : batch) * num_head_kv)); //(n,k,l)

@@ -261,15 +272,7 @@ class GemmUniversalAttention {
       auto gQ = local_tile(mQ_mk, TileShapeQK{}, make_coord(blk_m_coord, _, _), Step<_1, X, _1>{});
       auto gK = local_tile(mK_nk, TileShapeQK{}, make_coord(_, _, _), Step<X, _1, _1>{});
       auto gV = local_tile(mV_nk, TileShapePV{}, make_coord(_, blk_n_coord, _), Step<X, _1, _1>{});
-
-      const int seq_coord = cute::min(seq_len_qo, blk_m_coord * QK_BLK_M + (sub_group_id / PV_ATOM_N) * QK_SG_M);
-
-      const int causal_seq_len = cute::min(seq_len_kv, seq_coord) + QK_SG_M;
-      const int non_causal_seq_len = seq_len_kv;
-
-      const int nblock_limit = CausalMask ? cute::ceil_div(causal_seq_len, QK_BLK_N)
-                                          : cute::ceil_div(non_causal_seq_len, QK_BLK_N);
-
+
       auto mainloop_params = CollectiveMainloop::get_updated_copies(params.mainloop, params.problem_shape, batch_coord);

       auto tiled_prefetch_q = cute::prefetch_selector<Shape<Int<QK_BLK_M>, Int<QK_BLK_K>>, Num_SGs>(mainloop_params.gmem_tiled_copy_q);

@@ -361,7 +364,7 @@ class GemmUniversalAttention {
           int row_idx = m * Vec + seq_coord;
           CUTLASS_PRAGMA_UNROLL
           for (int row = 0; row < Vec; row++, row_idx++) { // 8
-            if (col_idx > row_idx)
+            if ((col_idx - full_tile_offset) > (row_idx - discard_seq_coord))
              tSr(row, m, n) = -INFINITY;
           }
         }
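With this change, a query tile whose rows all fall above the causal band is skipped outright, and surviving tiles only iterate over the key blocks that can be unmasked. The sketch below walks that bookkeeping on the host for the (2048, 1024) case from the comments above; the tile sizes, the clamp to the total block count, and the omission of the sub-group split are assumptions for illustration, not kernel code.

#include <algorithm>
#include <cstdio>

int main() {
  const int seq_len_qo = 2048, seq_len_kv = 1024;
  const int QK_BLK_M = 128, QK_BLK_N = 64;  // assumed tile sizes, for illustration only

  const int offset = std::min(seq_len_qo, seq_len_kv);   // 1024
  const int discard_seq_coord = seq_len_qo - offset;     // 1024: rows with no visible keys
  const int full_tile_offset  = seq_len_kv - offset;     // 0
  const int total_kv_blocks = (seq_len_kv + QK_BLK_N - 1) / QK_BLK_N;

  for (int blk_m = 0; blk_m * QK_BLK_M < seq_len_qo; blk_m++) {
    const int seq_coord = std::min(seq_len_qo, blk_m * QK_BLK_M);
    if (seq_coord < discard_seq_coord) {
      // whole tile lies above the causal band: the kernel 'continue's here
      std::printf("query tile %2d: skipped (fully masked)\n", blk_m);
      continue;
    }
    // band length this tile must cover; clamped to the total key-block count for display
    const int seq_len = full_tile_offset + std::min(seq_len_kv, seq_coord - discard_seq_coord) + QK_BLK_M;
    const int nblock_limit = std::min((seq_len + QK_BLK_N - 1) / QK_BLK_N, total_kv_blocks);
    std::printf("query tile %2d: visits %2d of %2d key blocks\n", blk_m, nblock_limit, total_kv_blocks);
  }
}

For these sizes, tiles 0 through 7 are skipped entirely and tiles 8 through 15 visit a growing number of key blocks, which is the work-saving behaviour the early continue and the new nblock_limit encode.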

benchmarks/pvc/flash_attention_v2/benchmark_runner.hpp

Lines changed: 10 additions & 4 deletions
@@ -179,7 +179,7 @@ template <class FMHAConfiguration> struct BenchmarkRunnerFMHA {
       int offset_o = 0;
       // loop over the batch dimension to compute the output
       // to avoid the risk of running out of device memory
-      int q_group_size = num_heads_q/num_heads_kv;
+      int q_group_size = num_heads_q / num_heads_kv;
       for (int b = 0; b < batch; b++) {
         if constexpr (isVarLen) {
           auto logical_problem_shape = cutlass::fmha::collective::apply_variable_length(problem_size, b);

@@ -218,12 +218,14 @@ template <class FMHAConfiguration> struct BenchmarkRunnerFMHA {

         // delete this memory as it is no longer needed
         block_S.reset();
-
+        auto offset = cute::min(seq_len_qo, seq_len_kv);
+        auto discard_seq_coord = seq_len_qo - offset;
+        auto full_tile_offset = seq_len_kv - offset;
         if constexpr (Causal) {
           // apply mask to S
           for (int row = 0; row < seq_len_qo; row++) {
             for (int col = 0; col < seq_len_kv; col++) {
-              if (col > row)
+              if ((col - full_tile_offset) > (row - discard_seq_coord))
                 host_S[col + row * seq_len_kv] = -INFINITY;
             }
           }

@@ -263,7 +265,11 @@ template <class FMHAConfiguration> struct BenchmarkRunnerFMHA {
           idx = row * seq_len_kv;
           sum_idx = row;
           for (int col = 0; col < seq_len_kv; col++, idx++) {
-            host_S[idx] /= sum_vec[sum_idx];
+            if (Causal && row < discard_seq_coord) {
+              host_S[idx] = 0;
+            } else {
+              host_S[idx] /= sum_vec[sum_idx];
+            }
           }
         }
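The host reference now zeroes the rows that sit above the causal band instead of dividing by their exponent sum, which is zero for a fully masked row. A simplified reference illustrating why (toy sizes, no max-subtraction, names not taken from the benchmark):

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  const int seq_len_qo = 4, seq_len_kv = 2;  // more queries than keys
  const int offset = std::min(seq_len_qo, seq_len_kv);
  const int discard_seq_coord = seq_len_qo - offset;  // 2: first two rows see no keys
  const int full_tile_offset  = seq_len_kv - offset;  // 0

  // Apply the bottom-right-aligned causal mask to a matrix of constant scores.
  std::vector<float> S(seq_len_qo * seq_len_kv, 1.0f);
  for (int row = 0; row < seq_len_qo; row++)
    for (int col = 0; col < seq_len_kv; col++)
      if ((col - full_tile_offset) > (row - discard_seq_coord))
        S[col + row * seq_len_kv] = -INFINITY;

  // Row-wise softmax; fully masked rows have sum == 0, so write 0 instead of dividing.
  for (int row = 0; row < seq_len_qo; row++) {
    float sum = 0.f;
    for (int col = 0; col < seq_len_kv; col++)
      sum += std::exp(S[col + row * seq_len_kv]);
    for (int col = 0; col < seq_len_kv; col++) {
      float &p = S[col + row * seq_len_kv];
      p = (row < discard_seq_coord) ? 0.f : std::exp(p) / sum;
      std::printf("%.2f%c", p, col + 1 == seq_len_kv ? '\n' : ' ');
    }
  }
}

The first discard_seq_coord rows come out as all zeros rather than NaN, matching the guard added to both the benchmark and the example runner.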

examples/sycl/06_pvc_flash_attention/pvc_flash_attn_runner.hpp

Lines changed: 9 additions & 4 deletions
@@ -232,13 +232,14 @@ template <class GemmKernel, bool isVarLen> struct ExampleRunner {

       // delete this memory as it is no longer needed
       block_S.reset();
-
-      // Apply upper-diagonal masking if required
+      auto offset = cute::min(seq_len_qo, seq_len_kv);
+      auto discard_seq_coord = seq_len_qo - offset;
+      auto full_tile_offset = seq_len_kv - offset;
       if (is_causal) {
         // apply mask to S
         for (int row = 0; row < seq_len_qo; row++) {
           for (int col = 0; col < seq_len_kv; col++) {
-            if (col > row)
+            if ((col - full_tile_offset) > (row - discard_seq_coord))
               host_S[col + row * seq_len_kv] = -INFINITY;
           }
         }

@@ -278,7 +279,11 @@ template <class GemmKernel, bool isVarLen> struct ExampleRunner {
         idx = row * seq_len_kv;
         sum_idx = row;
         for (int col = 0; col < seq_len_kv; col++, idx++) {
-          host_S[idx] /= sum_vec[sum_idx];
+          if (is_causal && row < discard_seq_coord) {
+            host_S[idx] = 0;
+          } else {
+            host_S[idx] /= sum_vec[sum_idx];
+          }
         }
       }
