@@ -342,6 +342,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_flash_attn_f32_f16_D112[GGML_TYPE_COUNT][2][2][2];
     vk_pipeline pipeline_flash_attn_f32_f16_D128[GGML_TYPE_COUNT][2][2][2];
     vk_pipeline pipeline_flash_attn_f32_f16_D256[GGML_TYPE_COUNT][2][2][2];
+    vk_pipeline pipeline_flash_attn_split_k_reduce;
 
     std::unordered_map<std::string, vk_pipeline_ref> pipelines;
     std::unordered_map<std::string, uint64_t> pipeline_descriptor_set_requirements;
@@ -493,6 +494,8 @@ struct vk_flash_attn_push_constants {
     float m1;
 
     uint32_t gqa_ratio;
+    uint32_t split_kv;
+    uint32_t k_num;
 };
 
 struct vk_op_push_constants {
@@ -1465,7 +1468,7 @@ static std::array<uint32_t, 2> fa_rows_cols(uint32_t D, uint32_t clamp, ggml_typ
 
     // small rows, large cols
     if (small_rows) {
-        return {flash_attention_num_small_rows, 128 };
+        return {flash_attention_num_small_rows, 64 };
     }
     // small cols to reduce register count
     if (ggml_is_quantized(type) || D == 256) {
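Aside, not part of the patch: the {rows, cols} pair returned by fa_rows_cols sets the per-workgroup tile shape, and cols appears to be the width of the KV slice a workgroup processes per iteration, so halving it trades more iterations over a given KV length for a smaller per-iteration register footprint. A minimal standalone C++ sketch of that trade, with a hypothetical KV length:

    // Hypothetical illustration only: iteration counts for the two tile
    // widths touched by this hunk, over an assumed KV length of 4096.
    #include <cstdint>
    #include <cstdio>

    static uint32_t ceil_div(uint32_t a, uint32_t b) { return (a + b - 1) / b; }

    int main() {
        const uint32_t KV = 4096;               // assumed, not from the patch
        const uint32_t cols_options[2] = {128, 64};
        for (uint32_t cols : cols_options) {
            std::printf("cols=%3u -> %2u KV iterations per workgroup\n",
                        cols, ceil_div(KV, cols));
        }
        return 0;
    }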
@@ -2269,6 +2272,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 2, 3 * sizeof(uint32_t), {1, 1, 1}, {}, 1, true);
 
     for (uint32_t i = 0; i < p021_max_gqa_ratio; ++i) {
         if (device->subgroup_add && device->subgroup_require_full_support) {
@@ -5309,9 +5313,38 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
         workgroups_y /= N;
     }
 
+    uint32_t split_kv = KV;
+    uint32_t split_k = 1;
+
+    if (gqa_ratio > 1 && ctx->device->shader_core_count > 0) {
+        GGML_ASSERT(workgroups_x == 1);
+        // Try to run two workgroups per SM.
+        split_k = ctx->device->shader_core_count * 2 / workgroups_y;
+        if (split_k > 1) {
+            // Try to evenly split KV into split_k chunks, but it needs to be a multiple
+            // of "align", so recompute split_k based on that.
+            split_kv = ROUNDUP_POW2(KV / split_k, pipelines[1]->align);
+            split_k = CEIL_DIV(KV, split_kv);
+            workgroups_x = split_k;
+        }
+    }
+
+    // Reserve space for split_k temporaries. For each split, we need to store the O matrix (D x ne1)
+    // and the per-row m and L values (ne1 rows).
+    const uint64_t split_k_size = split_k > 1 ? (D * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k : 0;
+    if (split_k_size > ctx->device->max_memory_allocation_size) {
+        GGML_ABORT("Requested preallocation size is too large");
+    }
+    if (ctx->prealloc_size_split_k < split_k_size) {
+        ctx->prealloc_size_split_k = split_k_size;
+    }
+
     if (dryrun) {
         // Request descriptor sets
         ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1);
+        if (split_k > 1) {
+            ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_flash_attn_split_k_reduce, 1);
+        }
         return;
     }
 
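For a concrete feel for the chunking logic above, here is a standalone C++ sketch with hypothetical numbers (32 shader cores, workgroups_y = 8, KV = 4096, align = 64). It assumes the usual ggml helper semantics: ROUNDUP_POW2(M, N) rounds M up to a multiple of the power-of-two N, and CEIL_DIV is ceiling division.

    // Sketch of the split_k chunking and temporary-buffer sizing above.
    // All concrete numbers are assumptions for illustration.
    #include <cstdint>
    #include <cstdio>

    #define ROUNDUP_POW2(M, N) (((M) + (N) - 1) & ~((N) - 1))
    #define CEIL_DIV(M, N)     (((M) + (N) - 1) / (N))

    int main() {
        const uint32_t shader_core_count = 32, workgroups_y = 8;
        const uint32_t KV = 4096, align = 64;

        uint32_t split_kv = KV;
        uint32_t split_k  = shader_core_count * 2 / workgroups_y; // 8
        if (split_k > 1) {
            split_kv = ROUNDUP_POW2(KV / split_k, align);         // 512 KV rows per split
            split_k  = CEIL_DIV(KV, split_kv);                    // still 8 here
        }

        // Per split: a D x ne1 block of O plus m and L (one float each) per row.
        const uint32_t D = 128, ne1 = 4;
        const uint64_t split_k_size =
            (D * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k;

        std::printf("split_k=%u split_kv=%u temp bytes=%llu\n",
                    split_k, split_kv, (unsigned long long) split_k_size);
        return 0;
    }

Note how the recompute can lower split_k when the align round-up enlarges the chunks: for example, KV = 1024 with 8 requested splits and align = 256 yields split_kv = 256 and split_k = 4.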
@@ -5332,8 +5365,6 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
-    ggml_vk_sync_buffers(subctx);
-
     vk_buffer d_Q = nullptr, d_K = nullptr, d_V = nullptr, d_D = nullptr, d_M = nullptr;
     size_t q_buf_offset = 0, k_buf_offset = 0, v_buf_offset = 0, d_buf_offset = 0, m_buf_offset = 0;
 
@@ -5398,16 +5429,45 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
         v_stride, (uint32_t)nbv2, (uint32_t)nbv3,
         nbm1,
         scale, max_bias, logit_softcap,
-        mask != nullptr, n_head_log2, m0, m1, gqa_ratio };
-    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
-        {
-            vk_subbuffer{d_Q, q_buf_offset, VK_WHOLE_SIZE},
-            vk_subbuffer{d_K, k_buf_offset, VK_WHOLE_SIZE},
-            vk_subbuffer{d_V, v_buf_offset, VK_WHOLE_SIZE},
-            vk_subbuffer{d_M, m_buf_offset, VK_WHOLE_SIZE},
-            vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
-        },
-        sizeof(vk_flash_attn_push_constants), &pc, { workgroups_x, workgroups_y, workgroups_z });
+        mask != nullptr, n_head_log2, m0, m1,
+        gqa_ratio, split_kv, split_k };
+
+    ggml_vk_sync_buffers(subctx);
+
+    if (split_k > 1) {
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
+            {
+                vk_subbuffer{d_Q, q_buf_offset, VK_WHOLE_SIZE},
+                vk_subbuffer{d_K, k_buf_offset, VK_WHOLE_SIZE},
+                vk_subbuffer{d_V, v_buf_offset, VK_WHOLE_SIZE},
+                vk_subbuffer{d_M, m_buf_offset, VK_WHOLE_SIZE},
+                vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE},
+            },
+            // We only use split_k when group query attention is enabled, which means
+            // there's no more than one tile of rows (i.e. workgroups_x would have been
+            // one). We reuse workgroups_x to mean the number of splits, so we need to
+            // cancel out the divide by wg_denoms[0].
+            sizeof(vk_flash_attn_push_constants), &pc, { workgroups_x * pipeline->wg_denoms[0], workgroups_y, workgroups_z });
+
+        ggml_vk_sync_buffers(subctx);
+        const std::array<uint32_t, 3> pc2 = { D, (uint32_t)ne1, split_k };
+        ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_flash_attn_split_k_reduce,
+            {
+                vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE},
+                vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
+            },
+            pc2.size() * uint32_t{sizeof(uint32_t)}, pc2.data(), { (uint32_t)ne1, 1, 1 });
+    } else {
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
+            {
+                vk_subbuffer{d_Q, q_buf_offset, VK_WHOLE_SIZE},
+                vk_subbuffer{d_K, k_buf_offset, VK_WHOLE_SIZE},
+                vk_subbuffer{d_V, v_buf_offset, VK_WHOLE_SIZE},
+                vk_subbuffer{d_M, m_buf_offset, VK_WHOLE_SIZE},
+                vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
+            },
+            sizeof(vk_flash_attn_push_constants), &pc, { workgroups_x, workgroups_y, workgroups_z });
+    }
 }
 
 static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op) {
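The fa_split_k_reduce shader itself is not part of this diff, so as a reading aid only, here is a scalar CPU sketch of the standard split-k combine such a reduction has to perform for one output row, assuming each split stores an unnormalized output O_i together with its row max m_i and exp-sum L_i (the "m and L values" from the sizing comment above); all numbers are made up:

    // Scalar sketch of a split-k softmax/output combine; not the shader code.
    #include <cmath>
    #include <cstdio>

    int main() {
        // Two hypothetical splits for a single output row of width D = 4.
        const int D = 4, k_num = 2;
        const float O[2][4] = {{1.0f, 2.0f, 3.0f, 4.0f},
                               {0.5f, 0.5f, 0.5f, 0.5f}};
        const float m[2] = {0.2f, -1.0f};  // per-split row maxima
        const float L[2] = {3.0f, 2.0f};   // per-split exp sums

        // Global row max across all splits.
        float m_max = m[0];
        for (int i = 1; i < k_num; ++i) m_max = std::fmax(m_max, m[i]);

        // Bring each split's denominator into the global frame and sum.
        float L_total = 0.0f;
        for (int i = 0; i < k_num; ++i) L_total += std::exp(m[i] - m_max) * L[i];

        // Rescale and accumulate partial outputs, normalizing once at the end.
        for (int d = 0; d < D; ++d) {
            float acc = 0.0f;
            for (int i = 0; i < k_num; ++i) acc += std::exp(m[i] - m_max) * O[i][d];
            std::printf("%f ", acc / L_total);
        }
        std::printf("\n");
        return 0;
    }

The dispatch geometry in the hunk matches this shape: the reduce launches one workgroup per output row ((uint32_t)ne1 in x), and each workgroup folds the split_k partial results for its row from ctx->prealloc_split_k into d_D.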