Skip to content

Commit 810dc62

Browse files
committed
vulkan: Implement split_k for coopmat2 flash attention.
When using group query attention, we have one workgroup per KV batch and this can be very few workgroups (e.g. just 8 in some models). Enable split_k to spread the work across SMs. This helps a lot when the KV cache is large.
1 parent 99a3792 commit 810dc62

File tree

5 files changed

+176
-16
lines changed

5 files changed

+176
-16
lines changed

Diff for: ggml/src/ggml-vulkan/ggml-vulkan.cpp

+73-13
Original file line numberDiff line numberDiff line change
@@ -342,6 +342,7 @@ struct vk_device_struct {
342342
vk_pipeline pipeline_flash_attn_f32_f16_D112[GGML_TYPE_COUNT][2][2][2];
343343
vk_pipeline pipeline_flash_attn_f32_f16_D128[GGML_TYPE_COUNT][2][2][2];
344344
vk_pipeline pipeline_flash_attn_f32_f16_D256[GGML_TYPE_COUNT][2][2][2];
345+
vk_pipeline pipeline_flash_attn_split_k_reduce;
345346

346347
std::unordered_map<std::string, vk_pipeline_ref> pipelines;
347348
std::unordered_map<std::string, uint64_t> pipeline_descriptor_set_requirements;
@@ -493,6 +494,8 @@ struct vk_flash_attn_push_constants {
493494
float m1;
494495

495496
uint32_t gqa_ratio;
497+
uint32_t split_kv;
498+
uint32_t k_num;
496499
};
497500

498501
struct vk_op_push_constants {
@@ -1465,7 +1468,7 @@ static std::array<uint32_t, 2> fa_rows_cols(uint32_t D, uint32_t clamp, ggml_typ
14651468

14661469
// small rows, large cols
14671470
if (small_rows) {
1468-
return {flash_attention_num_small_rows, 128};
1471+
return {flash_attention_num_small_rows, 64};
14691472
}
14701473
// small cols to reduce register count
14711474
if (ggml_is_quantized(type) || D == 256) {
@@ -2269,6 +2272,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
22692272
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
22702273

22712274
ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
2275+
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 2, 3 * sizeof(uint32_t), {1, 1, 1}, {}, 1, true);
22722276

22732277
for (uint32_t i = 0; i < p021_max_gqa_ratio; ++i) {
22742278
if (device->subgroup_add && device->subgroup_require_full_support) {
@@ -5309,9 +5313,38 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
53095313
workgroups_y /= N;
53105314
}
53115315

5316+
uint32_t split_kv = KV;
5317+
uint32_t split_k = 1;
5318+
5319+
if (gqa_ratio > 1 && ctx->device->shader_core_count > 0) {
5320+
GGML_ASSERT(workgroups_x == 1);
5321+
// Try to run two workgroups per SM.
5322+
split_k = ctx->device->shader_core_count * 2 / workgroups_y;
5323+
if (split_k > 1) {
5324+
// Try to evenly split KV into split_k chunks, but it needs to be a multiple
5325+
// of "align", so recompute split_k based on that.
5326+
split_kv = ROUNDUP_POW2(KV / split_k, pipelines[1]->align);
5327+
split_k = CEIL_DIV(KV, split_kv);
5328+
workgroups_x = split_k;
5329+
}
5330+
}
5331+
5332+
// Reserve space for split_k temporaries. For each split, we need to store the O matrix (D x ne1)
5333+
// and the per-row m and L values (ne1 rows).
5334+
const uint64_t split_k_size = split_k > 1 ? (D * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k : 0;
5335+
if (split_k_size > ctx->device->max_memory_allocation_size) {
5336+
GGML_ABORT("Requested preallocation size is too large");
5337+
}
5338+
if (ctx->prealloc_size_split_k < split_k_size) {
5339+
ctx->prealloc_size_split_k = split_k_size;
5340+
}
5341+
53125342
if (dryrun) {
53135343
// Request descriptor sets
53145344
ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1);
5345+
if (split_k > 1) {
5346+
ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_flash_attn_split_k_reduce, 1);
5347+
}
53155348
return;
53165349
}
53175350

@@ -5332,8 +5365,6 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
53325365
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
53335366
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
53345367

5335-
ggml_vk_sync_buffers(subctx);
5336-
53375368
vk_buffer d_Q = nullptr, d_K = nullptr, d_V = nullptr, d_D = nullptr, d_M = nullptr;
53385369
size_t q_buf_offset = 0, k_buf_offset = 0, v_buf_offset = 0, d_buf_offset = 0, m_buf_offset = 0;
53395370

@@ -5398,16 +5429,45 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
53985429
v_stride, (uint32_t)nbv2, (uint32_t)nbv3,
53995430
nbm1,
54005431
scale, max_bias, logit_softcap,
5401-
mask != nullptr, n_head_log2, m0, m1, gqa_ratio };
5402-
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
5403-
{
5404-
vk_subbuffer{d_Q, q_buf_offset, VK_WHOLE_SIZE},
5405-
vk_subbuffer{d_K, k_buf_offset, VK_WHOLE_SIZE},
5406-
vk_subbuffer{d_V, v_buf_offset, VK_WHOLE_SIZE},
5407-
vk_subbuffer{d_M, m_buf_offset, VK_WHOLE_SIZE},
5408-
vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
5409-
},
5410-
sizeof(vk_flash_attn_push_constants), &pc, { workgroups_x, workgroups_y, workgroups_z });
5432+
mask != nullptr, n_head_log2, m0, m1,
5433+
gqa_ratio, split_kv, split_k };
5434+
5435+
ggml_vk_sync_buffers(subctx);
5436+
5437+
if (split_k > 1) {
5438+
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
5439+
{
5440+
vk_subbuffer{d_Q, q_buf_offset, VK_WHOLE_SIZE},
5441+
vk_subbuffer{d_K, k_buf_offset, VK_WHOLE_SIZE},
5442+
vk_subbuffer{d_V, v_buf_offset, VK_WHOLE_SIZE},
5443+
vk_subbuffer{d_M, m_buf_offset, VK_WHOLE_SIZE},
5444+
vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE},
5445+
},
5446+
// We only use split_k when group query attention is enabled, which means
5447+
// there's no more than one tile of rows (i.e. workgroups_x would have been
5448+
// one). We reuse workgroups_x to mean the number of splits, so we need to
5449+
// cancel out the divide by wg_denoms[0].
5450+
sizeof(vk_flash_attn_push_constants), &pc, { workgroups_x * pipeline->wg_denoms[0], workgroups_y, workgroups_z });
5451+
5452+
ggml_vk_sync_buffers(subctx);
5453+
const std::array<uint32_t, 3> pc2 = { D, (uint32_t)ne1, split_k };
5454+
ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_flash_attn_split_k_reduce,
5455+
{
5456+
vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE},
5457+
vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
5458+
},
5459+
pc2.size() * uint32_t{sizeof(uint32_t)}, pc2.data(), { (uint32_t)ne1, 1, 1 });
5460+
} else {
5461+
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
5462+
{
5463+
vk_subbuffer{d_Q, q_buf_offset, VK_WHOLE_SIZE},
5464+
vk_subbuffer{d_K, k_buf_offset, VK_WHOLE_SIZE},
5465+
vk_subbuffer{d_V, v_buf_offset, VK_WHOLE_SIZE},
5466+
vk_subbuffer{d_M, m_buf_offset, VK_WHOLE_SIZE},
5467+
vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
5468+
},
5469+
sizeof(vk_flash_attn_push_constants), &pc, { workgroups_x, workgroups_y, workgroups_z });
5470+
}
54115471
}
54125472

54135473
static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op) {

Diff for: ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp

+37-3
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,8 @@ layout (push_constant) uniform parameter {
6363
float m1;
6464

6565
uint32_t gqa_ratio;
66+
uint32_t split_kv;
67+
uint32_t k_num;
6668
} p;
6769

6870
layout (binding = 0) readonly buffer Q {uint8_t data_q[];};
@@ -116,6 +118,16 @@ D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TY
116118
return elem;
117119
}
118120

121+
// Store column zero. This is used to save per-row m and L values for split_k.
122+
ACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
123+
{
124+
if (r < N && c == 0) {
125+
uint32_t offset = iq2 + r;
126+
data_o[o_offset + offset] = D_TYPE(elem);
127+
}
128+
return elem;
129+
}
130+
119131
// Load the slope matrix, indexed by Q's dimension 2.
120132
ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2)
121133
{
@@ -135,10 +147,18 @@ void main() {
135147
const uint32_t N = p.N;
136148
const uint32_t KV = p.KV;
137149

150+
uint32_t i = gl_WorkGroupID.x;
151+
uint32_t split_k_index = 0;
152+
153+
if (p.k_num > 1) {
154+
i = 0;
155+
split_k_index = gl_WorkGroupID.x;
156+
}
157+
138158
const uint32_t Tr = CEIL_DIV(N, Br);
139-
const uint32_t Tc = CEIL_DIV(KV, Bc);
140159

141-
const uint32_t i = gl_WorkGroupID.x;
160+
const uint32_t start_j = split_k_index * p.split_kv / Bc;
161+
const uint32_t end_j = CEIL_DIV(min(KV, (split_k_index + 1) * p.split_kv), Bc);
142162

143163
// When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y.
144164
// When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2.
@@ -218,7 +238,7 @@ void main() {
218238
}
219239

220240
[[dont_unroll]]
221-
for (uint32_t j = 0; j < Tc; ++j) {
241+
for (uint32_t j = start_j; j < end_j; ++j) {
222242

223243
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> S = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
224244

@@ -312,6 +332,20 @@ void main() {
312332
O = coopMatMulAdd(P_A, V, O);
313333
}
314334

335+
// If there is split_k, then the split_k resolve shader does the final
336+
// division by L. Store the intermediate O value and per-row m and L values.
337+
if (p.k_num > 1) {
338+
coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(O);
339+
340+
uint32_t o_offset = D * p.ne1 * split_k_index;
341+
coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);
342+
343+
o_offset = D * p.ne1 * p.k_num + p.ne1 * split_k_index * 2;
344+
coopMatPerElementNV(L, L, perElemOpStoreCol0, o_offset, iq2, N);
345+
coopMatPerElementNV(M, M, perElemOpStoreCol0, o_offset + p.ne1, iq2, N);
346+
return;
347+
}
348+
315349
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> Ldiag;
316350

317351
// resize L by using smear/reduce
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
#version 450
2+
3+
#extension GL_EXT_control_flow_attributes : enable
4+
5+
#define BLOCK_SIZE 32
6+
7+
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
8+
9+
layout (binding = 0) readonly buffer A {float data_a[];};
10+
layout (binding = 1) writeonly buffer D {float data_d[];};
11+
12+
layout (push_constant) uniform parameter {
13+
uint D;
14+
uint N;
15+
uint k_num;
16+
} p;
17+
18+
void main() {
19+
// Each workgroup handles a row
20+
const uint n = gl_WorkGroupID.x;
21+
const uint tid = gl_LocalInvocationID.x;
22+
23+
uint D = p.D;
24+
uint N = p.N;
25+
uint k_num = p.k_num;
26+
27+
uint l_offset = D * N * k_num + n;
28+
uint m_offset = D * N * k_num + N + n;
29+
uint lm_stride = N * 2;
30+
31+
// Compute the max m value for the row
32+
float m_max = -1.0/0.0;
33+
[[unroll]] for (uint k = 0; k < k_num; ++k) {
34+
float m = data_a[m_offset + k * lm_stride];
35+
m_max = max(m_max, m);
36+
}
37+
38+
// Compute L based on m_max
39+
float L = 0;
40+
[[unroll]] for (uint k = 0; k < k_num; ++k) {
41+
float l = data_a[l_offset + k * lm_stride];
42+
float m = data_a[m_offset + k * lm_stride];
43+
L += exp(m - m_max) * l;
44+
}
45+
46+
L = 1.0 / L;
47+
48+
// Scale and sum the O contributions based on m_max and store the result to memory
49+
for (uint d = tid; d < D; d += BLOCK_SIZE) {
50+
float O = 0.0;
51+
[[unroll]] for (uint k = 0; k < k_num; ++k) {
52+
uint o_offset = D * N * k + D * n + d;
53+
float m = data_a[m_offset + k * lm_stride];
54+
O += exp(m - m_max) * data_a[o_offset];
55+
}
56+
O *= L;
57+
data_d[D * n + d] = O;
58+
}
59+
}

Diff for: ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -458,6 +458,7 @@ void process_shaders() {
458458
string_to_spv("acc_f32", "acc.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
459459

460460
string_to_spv("split_k_reduce", "mul_mat_split_k_reduce.comp", {});
461+
string_to_spv("fa_split_k_reduce", "flash_attn_split_k_reduce.comp", {});
461462

462463
string_to_spv("mul_f32", "mul.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
463464

Diff for: tests/test-backend-ops.cpp

+6
Original file line numberDiff line numberDiff line change
@@ -4509,6 +4509,12 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
45094509
}
45104510
}
45114511

4512+
for (int kv : { 4096, 8192, 16384, }) {
4513+
for (int hs : { 64, 128, }) {
4514+
test_cases.emplace_back(new test_flash_attn_ext(hs, 8, 4, kv, 1, true, 0, 0, GGML_PREC_F32, GGML_TYPE_F16));
4515+
}
4516+
}
4517+
45124518
return test_cases;
45134519
}
45144520

0 commit comments

Comments
 (0)