Skip to content

Commit 3e6d1e4

Browse files
committed
ggml : FA supports F32 V
1 parent 8ca6e1c commit 3e6d1e4

File tree

1 file changed

+9
-5
lines changed

1 file changed

+9
-5
lines changed

ggml/src/ggml-cpu/ops.cpp

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6721,8 +6721,8 @@ static void ggml_compute_forward_flash_attn_ext_f16(
67216721
ggml_vec_dot_t const kq_vec_dot = ggml_get_type_traits_cpu(k->type)->vec_dot;
67226722
ggml_to_float_t const v_to_float = ggml_get_type_traits(v->type)->to_float;
67236723

6724-
GGML_ASSERT(q_to_vec_dot && "fattn: unsupported K-type");
6725-
GGML_ASSERT(v_to_float && "fattn: unsupported V-type");
6724+
GGML_ASSERT( q_to_vec_dot && "fattn: unsupported K-type");
6725+
GGML_ASSERT(v->type == GGML_TYPE_F32 || v_to_float && "fattn: unsupported V-type");
67266726

67276727
// loop over n_batch and n_head
67286728
for (int ir = ir0; ir < ir1; ++ir) {
@@ -6818,10 +6818,14 @@ static void ggml_compute_forward_flash_attn_ext_f16(
68186818
vs = expf(s - M);
68196819
}
68206820

6821-
v_to_float(v_data, V32, DV);
6822-
68236821
// V += v*expf(s - M)
6824-
ggml_vec_mad_f32(DV, VKQ32, V32, vs);
6822+
if (v_to_float) {
6823+
v_to_float(v_data, V32, DV);
6824+
ggml_vec_mad_f32(DV, VKQ32, V32, vs);
6825+
} else {
6826+
// V is F32
6827+
ggml_vec_mad_f32(DV, VKQ32, (const float *) v_data, vs);
6828+
}
68256829
}
68266830

68276831
S = S*ms + vs; // scale and increment sum with partial sum

0 commit comments

Comments
 (0)