
Commit 2df810a

vulkan: Use unclamped loads for flash attention mask
nem1 must be a multiple of GGML_KQ_MASK_PAD, and GGML_KQ_MASK_PAD is a multiple of the number of rows in the matrix. The KV dim is a multiple of the number of columns for the aligned shader. Together these guarantee the mask loads never go out of bounds, so the clamped load mode can be dropped.
1 parent 810dc62 · commit 2df810a
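To make the divisibility argument concrete, here is a minimal standalone C++ sketch (Br stands in for rows_cols[0], the per-tile row count; the values are illustrative, not taken from the code): if nem1 is a multiple of GGML_KQ_MASK_PAD, and GGML_KQ_MASK_PAD is a multiple of the tile row count, then nem1 is a multiple of the tile row count, so no row tile can read past the padded mask.

    #include <cassert>
    #include <cstdint>

    // Padding applied to the mask's dim1 (64, per the comment in the diff).
    #define GGML_KQ_MASK_PAD 64

    int main() {
        const uint32_t Br   = 32;                   // hypothetical per-tile row count (rows_cols[0])
        const uint32_t nem1 = 3 * GGML_KQ_MASK_PAD; // hypothetical padded mask dim1

        // The two asserts this commit adds:
        assert(GGML_KQ_MASK_PAD % Br == 0);   // checked when shaders are created
        assert(nem1 % GGML_KQ_MASK_PAD == 0); // checked at dispatch time

        // Divisibility is transitive, so every Br-row tile lies entirely
        // inside the padded mask and row loads never need clamping:
        assert(nem1 % Br == 0);
        return 0;
    }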

File tree

2 files changed, +6 -1 lines changed

Diff for: ggml/src/ggml-vulkan/ggml-vulkan.cpp (+5)
@@ -1761,6 +1761,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
         // can't use 256 for D==80.
         uint32_t wg_size = (small_rows && (D % 32) == 0) ? 256 : 128;
         auto rows_cols = fa_rows_cols(D, clamp, type, small_rows);
+        // mask dim1 is padded to 64, we rely on this to avoid clamping mask loads
+        GGML_ASSERT((GGML_KQ_MASK_PAD % rows_cols[0]) == 0);
         return {wg_size, rows_cols[0], rows_cols[1], (D), clamp};
     };

@@ -5294,6 +5296,9 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
                    // the "aligned" shader variant will forcibly align strides, for performance
                    (q_stride & 7) == 0 && (k_stride & 7) == 0 && (v_stride & 7) == 0;

+    // mask dim1 is padded to 64, we rely on this to avoid clamping mask loads
+    GGML_ASSERT((nem1 % GGML_KQ_MASK_PAD) == 0);
+
     vk_pipeline pipeline = pipelines[aligned];
     assert(pipeline);
Diff for: ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp (+1 -1)
@@ -256,7 +256,7 @@ void main() {
     }

     if (p.mask != 0) {
-        tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutM = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
+        tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
         tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
         // When using grouped query attention, all rows use the same mask.
         if (p.gqa_ratio > 1) {
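For intuition, here is a rough C++ analogy of what the shader change buys (the enum and function are hypothetical; the real mechanism is the GL_NV_cooperative_matrix2 tensor-layout clamp mode): hard-coding constant clamping makes every load bounds-checked, while passing the shader's Clamp parameter through lets the aligned variant, whose bounds are guaranteed by the asserts above, take the cheaper unclamped path.

    #include <cstdint>
    #include <vector>

    // Hypothetical stand-in for the shader's clamp modes.
    enum class ClampMode { Undefined, Constant };

    // Illustrative only: with ClampMode::Constant, out-of-bounds reads return
    // a constant (0 for the mask); with ClampMode::Undefined the caller must
    // guarantee row < dim1 and col < kv, mirroring the unclamped layout load.
    static float load_mask(const std::vector<float> & mask, uint32_t dim1,
                           uint32_t kv, uint32_t row, uint32_t col, ClampMode mode) {
        if (mode == ClampMode::Constant && (row >= dim1 || col >= kv)) {
            return 0.0f; // clamped path: safe for ragged sizes, but adds a branch
        }
        return mask[row * kv + col]; // unclamped path: no bounds check emitted
    }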
