@@ -12238,23 +12238,23 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     const int ith = params->ith;
     const int nth = params->nth;

-    const int64_t D = neq0;
-    const int64_t N = neq1;
+    const int64_t DK = nek0;
+    const int64_t DV = nev0;
+    const int64_t N  = neq1;

-    GGML_ASSERT(ne0 == D);
+    GGML_ASSERT(ne0 == DV);
     GGML_ASSERT(ne2 == N);

     // input tensor rows must be contiguous
     GGML_ASSERT(nbq0 == ggml_type_size(q->type));
     GGML_ASSERT(nbk0 == ggml_type_size(k->type));
     GGML_ASSERT(nbv0 == ggml_type_size(v->type));

-    GGML_ASSERT(neq0 == D);
-    GGML_ASSERT(nek0 == D);
-    GGML_ASSERT(nev0 == D);
+    GGML_ASSERT(neq0 == DK);
+    GGML_ASSERT(nek0 == DK);
+    GGML_ASSERT(nev0 == DV);

     GGML_ASSERT(neq1 == N);
-    GGML_ASSERT(nev0 == D);

     // dst cannot be transposed or permuted
     GGML_ASSERT(nb0 == sizeof(float));
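Note on the hunk above: the single head size D is split into DK (shared by Q and K, used for the score dot products) and DV (shared by V and the output row). This is what lets K and V carry different head dimensions, and it is why the output check becomes ne0 == DV. Below is a tiny standalone reference for one query row with DK != DV, using hypothetical sizes and a plain two-pass softmax instead of the streaming update in the kernel:

// Reference for a single query row with DK != DV (all sizes hypothetical):
// scores are DK-length dot products, the output row has DV elements.
#include <math.h>
#include <stdio.h>

#define DK  3  // hypothetical Q/K head size
#define DV  2  // hypothetical V head size
#define NKV 2  // hypothetical number of KV entries

int main(void) {
    const float q[DK]      = { 1.0f, 0.0f, 1.0f };
    const float k[NKV][DK] = { { 1, 0, 0 }, { 0, 0, 1 } };
    const float v[NKV][DV] = { { 1, 2 }, { 3, 4 } };

    // scores s[ic] = dot(q, k[ic]) over DK elements
    float s[NKV], smax = -INFINITY, ssum = 0.0f;
    for (int ic = 0; ic < NKV; ++ic) {
        s[ic] = 0.0f;
        for (int d = 0; d < DK; ++d) s[ic] += q[d]*k[ic][d];
        if (s[ic] > smax) smax = s[ic];
    }
    for (int ic = 0; ic < NKV; ++ic) { s[ic] = expf(s[ic] - smax); ssum += s[ic]; }

    // output row = softmax(s) @ V, hence DV elements (this is why ne0 == DV)
    float out[DV] = { 0 };
    for (int ic = 0; ic < NKV; ++ic) {
        for (int d = 0; d < DV; ++d) out[d] += (s[ic]/ssum)*v[ic][d];
    }

    printf("out = [%f, %f]\n", out[0], out[1]);
    return 0;
}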
@@ -12320,15 +12320,15 @@ static void ggml_compute_forward_flash_attn_ext_f16(
         float S = 0.0f;      // sum
         float M = -INFINITY; // maximum KQ value

-        float       * VKQ32 = (float       *) params->wdata + ith*(3*D + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator
-        float       * V32   =                 (VKQ32 + 1*D); // (temporary) FP32 V buffer
-        ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*D); // (temporary) FP16 VKQ accumulator
-        ggml_fp16_t * Q_q   = (ggml_fp16_t *) (VKQ32 + 2*D); // (temporary) buffer for Q converted to quantized/FP16
+        float       * VKQ32 = (float       *) params->wdata + ith*(1*DK + 2*DV + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator
+        float       * V32   =                 (VKQ32 + 1*DV); // (temporary) FP32 V buffer
+        ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*DV); // (temporary) FP16 VKQ accumulator
+        ggml_fp16_t * Q_q   = (ggml_fp16_t *) (VKQ32 + 2*DV); // (temporary) buffer for Q converted to quantized/FP16

         if (v->type == GGML_TYPE_F16) {
-            memset(VKQ16, 0, D*sizeof(ggml_fp16_t));
+            memset(VKQ16, 0, DV*sizeof(ggml_fp16_t));
         } else {
-            memset(VKQ32, 0, D*sizeof(float));
+            memset(VKQ32, 0, DV*sizeof(float));
         }

         const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
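The per-thread scratch region changes accordingly: each thread now gets 1*DK + 2*DV floats (plus cache-line padding) instead of 3*D, with the DV-sized VKQ accumulator and V buffer first and the DK-sized converted-Q buffer last. A standalone sketch of that layout, with hypothetical sizes and a plain malloc standing in for params->wdata:

// Sketch of the per-thread scratch layout used above (hypothetical sizes;
// plain malloc stands in for params->wdata).
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>

#define CACHE_LINE_SIZE_F32 16  // assumption: 64-byte cache line / 4-byte floats

int main(void) {
    const int64_t DK  = 192;    // hypothetical K/Q head size
    const int64_t DV  = 128;    // hypothetical V head size
    const int     nth = 4;      // hypothetical thread count

    // one slot of (1*DK + 2*DV) floats per thread, padded against false sharing
    const int64_t slot = 1*DK + 2*DV + CACHE_LINE_SIZE_F32;
    float * wdata = malloc(sizeof(float)*slot*nth);
    if (!wdata) return 1;

    for (int ith = 0; ith < nth; ++ith) {
        float * VKQ32 = wdata + ith*slot; // DV floats: FP32 V*softmax(KQ) accumulator
        float * V32   = VKQ32 + 1*DV;     // DV floats: dequantized V row (the FP16 accumulator aliases this)
        float * Q_q   = VKQ32 + 2*DV;     // remaining DK floats: Q converted for the dot product
        printf("thread %d: VKQ32=%p V32=%p Q_q=%p\n",
               ith, (void *) VKQ32, (void *) V32, (void *) Q_q);
    }

    free(wdata);
    return 0;
}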
@@ -12342,7 +12342,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
         const int iv2 = iq2 / rv2;

         const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
-        q_to_vec_dot(pq, Q_q, D);
+        q_to_vec_dot(pq, Q_q, DK);

         // online softmax / attention
         // loop over n_kv and n_head_kv
@@ -12356,7 +12356,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
             float s; // KQ value

             const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3);
-            kq_vec_dot(D, &s, 0, k_data, 0, Q_q, 0, 1);
+            kq_vec_dot(DK, &s, 0, k_data, 0, Q_q, 0, 1);

             s = s*scale; // scale KQ value

@@ -12380,45 +12380,45 @@ static void ggml_compute_forward_flash_attn_ext_f16(
                     ms = expf(Mold - M);

                     // V = V*expf(Mold - M)
-                    ggml_vec_scale_f16(D, VKQ16, ms);
+                    ggml_vec_scale_f16(DV, VKQ16, ms);
                 } else {
                     // no new maximum, ms == 1.0f, vs != 1.0f
                     vs = expf(s - M);
                 }

                 // V += v*expf(s - M)
-                ggml_vec_mad_f16(D, VKQ16, (const ggml_fp16_t *) v_data, vs);
+                ggml_vec_mad_f16(DV, VKQ16, (const ggml_fp16_t *) v_data, vs);
             } else {
                 if (s > M) {
                     // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
                     M = s;
                     ms = expf(Mold - M);

                     // V = V*expf(Mold - M)
-                    ggml_vec_scale_f32(D, VKQ32, ms);
+                    ggml_vec_scale_f32(DV, VKQ32, ms);
                 } else {
                     // no new maximum, ms == 1.0f, vs != 1.0f
                     vs = expf(s - M);
                 }

-                v_to_float(v_data, V32, D);
+                v_to_float(v_data, V32, DV);

                 // V += v*expf(s - M)
-                ggml_vec_mad_f32(D, VKQ32, V32, vs);
+                ggml_vec_mad_f32(DV, VKQ32, V32, vs);
             }

             S = S*ms + vs; // scale and increment sum with partial sum
         }

         if (v->type == GGML_TYPE_F16) {
-            for (int64_t d = 0; d < D; ++d) {
+            for (int64_t d = 0; d < DV; ++d) {
                 VKQ32[d] = GGML_FP16_TO_FP32(VKQ16[d]);
             }
         }

         // V /= S
         const float S_inv = 1.0f/S;
-        ggml_vec_scale_f32(D, VKQ32, S_inv);
+        ggml_vec_scale_f32(DV, VKQ32, S_inv);

         // dst indices
         const int i1 = iq1;
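The ms/vs bookkeeping above is the online (streaming) softmax: a single pass keeps a running maximum M, a running denominator S, and an accumulator, and both the accumulator and S are rescaled by expf(Mold - M) whenever a larger score appears, so dividing by S at the end matches a two-pass softmax. A minimal scalar sketch of that update, with hypothetical data and plain loops in place of ggml's vector helpers:

// Minimal scalar sketch of the online softmax / attention update used above
// (hypothetical data; ggml's vectorized helpers replaced by plain loops).
#include <math.h>
#include <stdio.h>

#define DV  4  // hypothetical V head size
#define NKV 3  // hypothetical number of KV entries

int main(void) {
    // hypothetical attention scores and V rows
    const float s_kv[NKV]     = { 0.5f, 2.0f, 1.0f };
    const float v_kv[NKV][DV] = { {1,0,0,0}, {0,1,0,0}, {0,0,1,0} };

    float M = -INFINITY; // running maximum score
    float S = 0.0f;      // running softmax denominator
    float VKQ[DV] = {0}; // running sum of expf(s - M) * v

    for (int ic = 0; ic < NKV; ++ic) {
        const float s = s_kv[ic];
        float ms = 1.0f;             // rescale factor for what was accumulated so far
        float vs = 1.0f;             // weight of the current V row
        if (s > M) {
            const float Mold = M;
            M  = s;
            ms = expf(Mold - M);     // shrink the old accumulator to the new reference
            for (int d = 0; d < DV; ++d) VKQ[d] *= ms;
        } else {
            vs = expf(s - M);
        }
        for (int d = 0; d < DV; ++d) VKQ[d] += vs*v_kv[ic][d];
        S = S*ms + vs;               // the denominator gets the same rescale
    }

    for (int d = 0; d < DV; ++d) printf("%f ", VKQ[d]/S); // equals softmax(s) @ V
    printf("\n");
    return 0;
}

With the split head sizes, everything touching this accumulator runs over DV, while only the score computation runs over DK.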
@@ -15277,7 +15277,6 @@ struct ggml_cplan ggml_graph_plan(
         size_t cur = 0;

         if (!ggml_cpu_extra_work_size(n_threads, node, &cur)) {
-
             switch (node->op) {
                 case GGML_OP_CPY:
                 case GGML_OP_DUP:
@@ -15386,9 +15385,10 @@ struct ggml_cplan ggml_graph_plan(
                     } break;
                 case GGML_OP_FLASH_ATTN_EXT:
                     {
-                        const int64_t ne00 = node->src[0]->ne[0]; // D
+                        const int64_t ne10 = node->src[1]->ne[0]; // DK
+                        const int64_t ne20 = node->src[2]->ne[0]; // DV

-                        cur = 3*sizeof(float)*ne00*n_tasks; // 3x head size/thread
+                        cur = sizeof(float)*(1*ne10 + 2*ne20)*n_tasks; // 1x head size K + 2x head size V (per thread)
                     } break;
                 case GGML_OP_FLASH_ATTN_BACK:
                     {
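The planner change mirrors the kernel's layout: the FLASH_ATTN_EXT work buffer is sized as 1*DK + 2*DV floats per task instead of 3*D, reading DK from the K tensor (src[1]) and DV from the V tensor (src[2]). A quick standalone check of that arithmetic with hypothetical sizes:

// Standalone sketch of the work-size arithmetic above (hypothetical head sizes).
#include <stdint.h>
#include <stdio.h>

int main(void) {
    const int64_t ne10 = 192;  // hypothetical DK = K head size (node->src[1]->ne[0])
    const int64_t ne20 = 128;  // hypothetical DV = V head size (node->src[2]->ne[0])
    const int  n_tasks = 8;    // hypothetical number of tasks/threads

    // new formula: one DK-sized Q buffer + two DV-sized buffers per task
    const size_t cur = sizeof(float)*(1*ne10 + 2*ne20)*n_tasks;

    printf("FLASH_ATTN_EXT work size: %zu bytes (%lld floats per task)\n",
           cur, (long long)(1*ne10 + 2*ne20));
    return 0;
}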