vulkan: enable coopmat2 FA gqa and split_k optimizations more often #12931

Merged: 1 commit, merged on Apr 16, 2025.

ggml/src/ggml-vulkan/ggml-vulkan.cpp (6 changes: 3 additions & 3 deletions)

@@ -5531,7 +5531,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     uint32_t workgroups_y = (uint32_t)neq2;
     uint32_t workgroups_z = (uint32_t)neq3;
 
-    if (N == 1 && qk_ratio > 1 && is_pow2(qk_ratio) && gqa_ratio <= flash_attention_num_small_rows &&
+    if (N == 1 && qk_ratio > 1 && gqa_ratio <= flash_attention_num_small_rows &&
         qk_ratio * nek2 == neq2 && nek2 == nev2 && neq3 == 1 && nek3 == 1 && nev3 == 1) {
         // grouped query attention - make the N dimension equal to gqa_ratio, reduce
         // workgroups proportionally in y dimension. The shader will detect gqa_ratio > 1
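
Dropping is_pow2(qk_ratio) from this gate lets any head ratio up to flash_attention_num_small_rows take the grouped-query path, not just powers of two. A rough sketch of the launch-shape effect described in the comment above, using made-up head counts (the actual assignments live in the collapsed lines below the hunk):

    #include <cstdint>

    // Hypothetical model: 48 Q heads sharing 8 KV heads, i.e. gqa_ratio = 6,
    // a ratio the old is_pow2() check would have rejected.
    constexpr uint32_t neq2         = 48;               // number of Q heads
    constexpr uint32_t nek2         = 8;                // number of KV heads
    constexpr uint32_t gqa_ratio    = neq2 / nek2;      // 6
    constexpr uint32_t N            = gqa_ratio;        // tile rows carry one head group
    constexpr uint32_t workgroups_y = neq2 / gqa_ratio; // 8 workgroups in y instead of 48
    static_assert(workgroups_y == nek2, "one workgroup per KV head");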

@@ -5544,8 +5544,8 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     uint32_t split_kv = KV;
     uint32_t split_k = 1;
 
-    if (gqa_ratio > 1 && ctx->device->shader_core_count > 0) {
-        GGML_ASSERT(workgroups_x == 1);
+    // Try to use split_k when KV is large enough to be worth the overhead
+    if (workgroups_x == 1 && ctx->device->shader_core_count > 0 && KV >= 512) {
         // Try to run two workgroups per SM.
         split_k = ctx->device->shader_core_count * 2 / workgroups_y;
         if (split_k > 1) {
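
A worked instance of the split-k heuristic with made-up device numbers (the split_kv computation here is a plain ceiling division; the exact rounding lives in the collapsed part of the hunk):

    #include <cstdint>

    // Hypothetical device: 32 shader cores, 8 workgroups in y, and KV = 8192 (>= 512).
    constexpr uint32_t shader_core_count = 32;
    constexpr uint32_t workgroups_y      = 8;
    constexpr uint32_t KV                = 8192;
    // Target two workgroups per SM, then split KV into split_k chunks.
    constexpr uint32_t split_k  = shader_core_count * 2 / workgroups_y;  // 8
    constexpr uint32_t split_kv = (KV + split_k - 1) / split_k;          // 1024 KV positions per split
    static_assert(split_k == 8 && split_kv == 1024, "worked example");

Compared to the old gate, the split no longer requires gqa_ratio > 1: it now applies whenever there is a single workgroup in x (previously only asserted) and KV is at least 512, so the split is only attempted when it is worth the reduction overhead.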

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp (2 changes: 1 addition & 1 deletion)

@@ -131,7 +131,7 @@ ACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in A
 // Load the slope matrix, indexed by Q's dimension 2.
 ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2)
 {
-    const uint32_t h = iq2 + (r & (p.gqa_ratio - 1));
+    const uint32_t h = iq2 + (r % p.gqa_ratio);
 
     const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1);
     const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1);
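
This is the shader-side counterpart of removing is_pow2(): r & (p.gqa_ratio - 1) only equals r % p.gqa_ratio when gqa_ratio is a power of two. A quick standalone check in C++ mirroring the GLSL expression:

    #include <cassert>
    #include <cstdint>

    int main() {
        // For a power-of-two ratio the mask and the modulo agree for every row r.
        for (uint32_t r = 0; r < 16; ++r) {
            assert((r & (4u - 1)) == (r % 4u));
        }
        // For gqa_ratio = 6 they diverge: r = 7 gives 7 & 5 = 5 but 7 % 6 = 1,
        // which would select the wrong ALiBi slope for that head.
        assert((7u & (6u - 1)) != (7u % 6u));
        return 0;
    }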

tests/test-backend-ops.cpp (4 changes: 3 additions & 1 deletion)

@@ -4532,7 +4532,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
 
     for (int kv : { 4096, 8192, 16384, }) {
         for (int hs : { 64, 128, }) {
-            test_cases.emplace_back(new test_flash_attn_ext(hs, hs, 8, 4, kv, 1, true, 0, 0, GGML_PREC_F32, GGML_TYPE_F16));
+            for (int nr : { 1, 4, }) {
+                test_cases.emplace_back(new test_flash_attn_ext(hs, hs, 8, nr, kv, 1, true, 0, 0, GGML_PREC_F32, GGML_TYPE_F16));
+            }
         }
     }
 
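
The added nr loop makes the perf sweep cover both the ungrouped (nr = 1) and grouped (nr = 4) cases for every kv/head-size pair. Assuming nr is the Q-per-KV-head repeat factor, as in the other flash_attn_ext tests, the loops now expand to twelve configurations, sketched here:

    #include <cstdio>
    #include <initializer_list>

    int main() {
        // Enumerate the 3 x 2 x 2 = 12 perf cases the nested loops now produce.
        for (int kv : { 4096, 8192, 16384 })
            for (int hs : { 64, 128 })
                for (int nr : { 1, 4 })
                    std::printf("flash_attn_ext perf: hs=%d nh=8 nr=%d kv=%d\n", hs, nr, kv);
        return 0;
    }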