Commit 27533e7

metal : improve FA + improve MoE (llama/12612)
* ggml : FA with different K, V head sizes (CPU) ggml-ci
* metal : add FA with HS=192
* metal : extend FA to support different K and V head sizes ggml-ci
* metal : add FA vector kernels for heads K 192 and V 128 ggml-ci
* ggml : restrict op on other backends to equal head sizes ggml-ci
* metal : optimize FA-vec kernel ggml-ci
* metal : FA remove mq registers
* metal : improve MoE mul_mat_id condition ggml-ci
* metal : fix comments + remove unnecessary addition ggml-ci
* metal : avoid too much shared memory usage with mul_mat_id ggml-ci
Parent commit: 1b81415

8 files changed: +883, -678 lines

ggml/include/ggml.h

Lines changed: 5 additions & 5 deletions

@@ -1791,11 +1791,11 @@ extern "C" {
 
 #define GGML_KQ_MASK_PAD 64
 
-    // q:    [n_embd, n_batch,     n_head,    1]
-    // k:    [n_embd, n_kv,        n_head_kv, 1]
-    // v:    [n_embd, n_kv,        n_head_kv, 1] !! not transposed !!
-    // mask: [n_kv, n_batch_pad, 1, 1] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !!
-    // res:  [n_embd, n_head, n_batch, 1] !! permuted !!
+    // q:    [n_embd_k, n_batch,     n_head,    1]
+    // k:    [n_embd_k, n_kv,        n_head_kv, 1]
+    // v:    [n_embd_v, n_kv,        n_head_kv, 1] !! not transposed !!
+    // mask: [n_kv, n_batch_pad, 1, 1] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !!
+    // res:  [n_embd_v, n_head, n_batch, 1] !! permuted !!
     GGML_API struct ggml_tensor * ggml_flash_attn_ext(
             struct ggml_context * ctx,
             struct ggml_tensor  * q,
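With this change the K and V head sizes passed to ggml_flash_attn_ext no longer have to match. Below is a minimal sketch of how the updated shape contract composes for the K=192 / V=128 heads named in the commit message; the F32/F16 tensor types, the 1/sqrt(DK) scale, the trailing max_bias/logit_softcap arguments, and the helper name build_fa_node are assumptions for illustration, not part of this diff:

#include "ggml.h"
#include <math.h>

// hypothetical helper: builds one flash-attention node with DK != DV
static struct ggml_tensor * build_fa_node(struct ggml_context * ctx,
                                          int64_t n_batch, int64_t n_kv,
                                          int64_t n_head,  int64_t n_head_kv) {
    const int64_t DK = 192; // Q/K head size
    const int64_t DV = 128; // V head size

    struct ggml_tensor * q    = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, DK, n_batch, n_head,    1);
    struct ggml_tensor * k    = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, DK, n_kv,    n_head_kv, 1);
    struct ggml_tensor * v    = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, DV, n_kv,    n_head_kv, 1); // not transposed
    struct ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F16,
                                                   n_kv, GGML_PAD(n_batch, GGML_KQ_MASK_PAD), 1, 1);

    // the result is [DV, n_head, n_batch, 1] (permuted) - it follows the V head size, not the K head size
    return ggml_flash_attn_ext(ctx, q, k, v, mask, 1.0f/sqrtf((float) DK), 0.0f, 0.0f);
}

Backends whose supports_op check still requires equal head sizes (see the CUDA hunk below) are expected to reject such a node, so it falls back to the CPU implementation.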

ggml/src/ggml-cpu/ggml-cpu.c

Lines changed: 25 additions & 25 deletions

@@ -12238,23 +12238,23 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     const int ith = params->ith;
     const int nth = params->nth;
 
-    const int64_t D = neq0;
-    const int64_t N = neq1;
+    const int64_t DK = nek0;
+    const int64_t DV = nev0;
+    const int64_t N  = neq1;
 
-    GGML_ASSERT(ne0 == D);
+    GGML_ASSERT(ne0 == DV);
     GGML_ASSERT(ne2 == N);
 
     // input tensor rows must be contiguous
     GGML_ASSERT(nbq0 == ggml_type_size(q->type));
     GGML_ASSERT(nbk0 == ggml_type_size(k->type));
     GGML_ASSERT(nbv0 == ggml_type_size(v->type));
 
-    GGML_ASSERT(neq0 == D);
-    GGML_ASSERT(nek0 == D);
-    GGML_ASSERT(nev0 == D);
+    GGML_ASSERT(neq0 == DK);
+    GGML_ASSERT(nek0 == DK);
+    GGML_ASSERT(nev0 == DV);
 
     GGML_ASSERT(neq1 == N);
-    GGML_ASSERT(nev0 == D);
 
     // dst cannot be transposed or permuted
     GGML_ASSERT(nb0 == sizeof(float));
@@ -12320,15 +12320,15 @@ static void ggml_compute_forward_flash_attn_ext_f16(
         float S = 0.0f;      // sum
         float M = -INFINITY; // maximum KQ value
 
-        float       * VKQ32 = (float       *) params->wdata + ith*(3*D + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator
-        float       * V32   =                 (VKQ32 + 1*D); // (temporary) FP32 V buffer
-        ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*D); // (temporary) FP16 VKQ accumulator
-        ggml_fp16_t * Q_q   = (ggml_fp16_t *) (VKQ32 + 2*D); // (temporary) buffer for Q converted to quantized/FP16
+        float       * VKQ32 = (float       *) params->wdata + ith*(1*DK + 2*DV + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator
+        float       * V32   =                 (VKQ32 + 1*DV); // (temporary) FP32 V buffer
+        ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*DV); // (temporary) FP16 VKQ accumulator
+        ggml_fp16_t * Q_q   = (ggml_fp16_t *) (VKQ32 + 2*DV); // (temporary) buffer for Q converted to quantized/FP16
 
         if (v->type == GGML_TYPE_F16) {
-            memset(VKQ16, 0, D*sizeof(ggml_fp16_t));
+            memset(VKQ16, 0, DV*sizeof(ggml_fp16_t));
         } else {
-            memset(VKQ32, 0, D*sizeof(float));
+            memset(VKQ32, 0, DV*sizeof(float));
         }
 
         const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
@@ -12342,7 +12342,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
         const int iv2 = iq2 / rv2;
 
         const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
-        q_to_vec_dot(pq, Q_q, D);
+        q_to_vec_dot(pq, Q_q, DK);
 
         // online softmax / attention
         // loop over n_kv and n_head_kv
@@ -12356,7 +12356,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
             float s; // KQ value
 
             const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3);
-            kq_vec_dot(D, &s, 0, k_data, 0, Q_q, 0, 1);
+            kq_vec_dot(DK, &s, 0, k_data, 0, Q_q, 0, 1);
 
             s = s*scale; // scale KQ value
 
@@ -12380,45 +12380,45 @@ static void ggml_compute_forward_flash_attn_ext_f16(
                     ms = expf(Mold - M);
 
                     // V = V*expf(Mold - M)
-                    ggml_vec_scale_f16(D, VKQ16, ms);
+                    ggml_vec_scale_f16(DV, VKQ16, ms);
                 } else {
                     // no new maximum, ms == 1.0f, vs != 1.0f
                     vs = expf(s - M);
                 }
 
                 // V += v*expf(s - M)
-                ggml_vec_mad_f16(D, VKQ16, (const ggml_fp16_t *) v_data, vs);
+                ggml_vec_mad_f16(DV, VKQ16, (const ggml_fp16_t *) v_data, vs);
             } else {
                 if (s > M) {
                     // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
                     M = s;
                     ms = expf(Mold - M);
 
                     // V = V*expf(Mold - M)
-                    ggml_vec_scale_f32(D, VKQ32, ms);
+                    ggml_vec_scale_f32(DV, VKQ32, ms);
                 } else {
                     // no new maximum, ms == 1.0f, vs != 1.0f
                     vs = expf(s - M);
                 }
 
-                v_to_float(v_data, V32, D);
+                v_to_float(v_data, V32, DV);
 
                 // V += v*expf(s - M)
-                ggml_vec_mad_f32(D, VKQ32, V32, vs);
+                ggml_vec_mad_f32(DV, VKQ32, V32, vs);
             }
 
             S = S*ms + vs; // scale and increment sum with partial sum
         }
 
         if (v->type == GGML_TYPE_F16) {
-            for (int64_t d = 0; d < D; ++d) {
+            for (int64_t d = 0; d < DV; ++d) {
                 VKQ32[d] = GGML_FP16_TO_FP32(VKQ16[d]);
             }
         }
 
         // V /= S
         const float S_inv = 1.0f/S;
-        ggml_vec_scale_f32(D, VKQ32, S_inv);
+        ggml_vec_scale_f32(DV, VKQ32, S_inv);
 
         // dst indices
         const int i1 = iq1;
@@ -15277,7 +15277,6 @@ struct ggml_cplan ggml_graph_plan(
         size_t cur = 0;
 
         if (!ggml_cpu_extra_work_size(n_threads, node, &cur)) {
-
            switch (node->op) {
                case GGML_OP_CPY:
                case GGML_OP_DUP:
@@ -15386,9 +15385,10 @@ struct ggml_cplan ggml_graph_plan(
                    } break;
                case GGML_OP_FLASH_ATTN_EXT:
                    {
-                        const int64_t ne00 = node->src[0]->ne[0]; // D
+                        const int64_t ne10 = node->src[1]->ne[0]; // DK
+                        const int64_t ne20 = node->src[2]->ne[0]; // DV
 
-                        cur = 3*sizeof(float)*ne00*n_tasks; // 3x head size/thread
+                        cur = sizeof(float)*(1*ne10 + 2*ne20)*n_tasks; // 1x head size K + 2x head size V (per thread)
                    } break;
                case GGML_OP_FLASH_ATTN_BACK:
                    {
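The ggml_graph_plan hunk and the VKQ32/V32/VKQ16/Q_q hunk above change the per-thread FA scratch buffer from 3*D floats to 1*DK + 2*DV floats. A standalone sketch of that sizing and layout follows; the helper name, the example head sizes, and the thread count are illustrative, and the per-thread cache-line padding handled by the real code is left out:

#include <stddef.h>
#include <stdint.h>
#include <stdio.h>

// per-thread layout, in float slots:
//   [0,      DV)       VKQ32 : FP32 VKQ accumulator
//   [DV,     2*DV)     V32   : temporary FP32 V buffer (aliased by the FP16 VKQ16 accumulator)
//   [2*DV,   2*DV+DK)  Q_q   : Q converted to the K vec-dot type (quantized/FP16)
static size_t fa_work_size(int64_t DK, int64_t DV, int n_tasks) {
    return sizeof(float)*(1*DK + 2*DV)*n_tasks;
}

int main(void) {
    // e.g. the K=192 / V=128 head sizes targeted by this commit, on 8 threads:
    // 4*(192 + 2*128)*8 = 14336 bytes (the old 3*D formula cannot express this split)
    printf("%zu\n", fa_work_size(192, 128, 8));
    return 0;
}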

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 7 additions & 0 deletions

@@ -3232,6 +3232,13 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
 #ifndef FLASH_ATTN_AVAILABLE
             return false;
 #endif // FLASH_ATTN_AVAILABLE
+            if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
+                // different head sizes of K and V are not supported yet
+                return false;
+            }
+            if (op->src[0]->ne[0] == 192) {
+                return false;
+            }
             if (op->src[0]->ne[3] != 1) {
                 return false;
             }
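For reference, the same shape restrictions expressed as a standalone predicate; the function name cuda_fa_shapes_supported is hypothetical, and in the backend this logic lives in the GGML_OP_FLASH_ATTN_EXT branch of the supports_op callback shown above:

#include <stdbool.h>
#include "ggml.h"

// hypothetical predicate mirroring the hunk above
static bool cuda_fa_shapes_supported(const struct ggml_tensor * op) {
    if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
        return false; // different K and V head sizes are not supported yet on CUDA
    }
    if (op->src[0]->ne[0] == 192) {
        return false; // head size 192 is rejected here (covered by the new CPU/Metal paths)
    }
    return true;
}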

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 6 additions & 3 deletions

@@ -219,9 +219,12 @@ typedef struct {
     int32_t  ne11;
     int32_t  ne_12_2; // assume K and V are same shape
     int32_t  ne_12_3;
-    uint64_t nb_12_1;
-    uint64_t nb_12_2;
-    uint64_t nb_12_3;
+    uint64_t nb11;
+    uint64_t nb12;
+    uint64_t nb13;
+    uint64_t nb21;
+    uint64_t nb22;
+    uint64_t nb23;
     uint64_t nb31;
     int32_t  ne1;
     int32_t  ne2;
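Because K and V no longer share a shape, the kernel arguments now carry a separate stride set for each. A self-contained sketch of why the single nb_12_* set had to be split; the concrete sizes are illustrative, and in the backend these values come from k->nb[1..3] and v->nb[1..3]:

#include <stdint.h>
#include <stdio.h>

int main(void) {
    const int64_t  DK = 192, DV = 128, n_kv = 4096, n_head_kv = 8;
    const uint64_t ts = 2; // bytes per element for an F16 K/V cache

    // contiguous K = [DK, n_kv, n_head_kv, 1] and V = [DV, n_kv, n_head_kv, 1]
    const uint64_t nb11 = ts*DK, nb12 = nb11*n_kv, nb13 = nb12*n_head_kv; // K strides
    const uint64_t nb21 = ts*DV, nb22 = nb21*n_kv, nb23 = nb22*n_head_kv; // V strides

    // with DK != DV every stride differs, so one shared nb_12_* set cannot describe both
    printf("K: nb11=%llu nb12=%llu nb13=%llu\n", (unsigned long long) nb11, (unsigned long long) nb12, (unsigned long long) nb13);
    printf("V: nb21=%llu nb22=%llu nb23=%llu\n", (unsigned long long) nb21, (unsigned long long) nb22, (unsigned long long) nb23);
    return 0;
}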
