From 12b198fa23803138836e172e1d79915f3c8412b7 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Wed, 2 Apr 2025 09:23:58 -0500 Subject: [PATCH] vulkan: Use unclamped loads for flash attention mask nem1 must be a multiple of GGML_KQ_MASK_PAD, and GGML_KQ_MASK_PAD is a multiple of the number of rows in the matrix. The KV dim is a multiple of the number of columns for the aligned shader. --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 5 +++++ ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp | 2 +- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index f6cc28603448a..976feeff0793f 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -1787,6 +1787,8 @@ static void ggml_vk_load_shaders(vk_device& device) { // can't use 256 for D==80. uint32_t wg_size = (small_rows && (D % 32) == 0) ? 256 : 128; auto rows_cols = fa_rows_cols(D, clamp, type, small_rows); + // mask dim1 is padded to 64, we rely on this to avoid clamping mask loads + GGML_ASSERT((GGML_KQ_MASK_PAD % rows_cols[0]) == 0); return {wg_size, rows_cols[0], rows_cols[1], (D), clamp}; }; @@ -5464,6 +5466,9 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx // the "aligned" shader variant will forcibly align strides, for performance (q_stride & 7) == 0 && (k_stride & 7) == 0 && (v_stride & 7) == 0; + // mask dim1 is padded to 64, we rely on this to avoid clamping mask loads + GGML_ASSERT((nem1 % GGML_KQ_MASK_PAD) == 0); + vk_pipeline pipeline = pipelines[aligned]; assert(pipeline); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp index d78092000d839..eedbc6f8b0e9c 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp @@ -256,7 +256,7 @@ void main() { } if (p.mask != 0) { - tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutM = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV); + tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp); tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV); // When using grouped query attention, all rows use the same mask. if (p.gqa_ratio > 1) {