File tree (expand / collapse): 1 file changed, +9 -5 lines changed
Columns: original file line number | diff line number | diff line change
@@ -6721,8 +6721,8 @@ static void ggml_compute_forward_flash_attn_ext_f16(
6721
6721
ggml_vec_dot_t const kq_vec_dot = ggml_get_type_traits_cpu (k->type )->vec_dot ;
6722
6722
ggml_to_float_t const v_to_float = ggml_get_type_traits (v->type )->to_float ;
6723
6723
6724
- GGML_ASSERT (q_to_vec_dot && " fattn: unsupported K-type" );
6725
- GGML_ASSERT (v_to_float && " fattn: unsupported V-type" );
6724
+ GGML_ASSERT ( q_to_vec_dot && " fattn: unsupported K-type" );
6725
+ GGML_ASSERT (v-> type == GGML_TYPE_F32 || v_to_float && " fattn: unsupported V-type" );
6726
6726
6727
6727
// loop over n_batch and n_head
6728
6728
for (int ir = ir0; ir < ir1; ++ir) {
@@ -6818,10 +6818,14 @@ static void ggml_compute_forward_flash_attn_ext_f16(
6818
6818
vs = expf (s - M);
6819
6819
}
6820
6820
6821
- v_to_float (v_data, V32, DV);
6822
-
6823
6821
// V += v*expf(s - M)
6824
- ggml_vec_mad_f32 (DV, VKQ32, V32, vs);
6822
+ if (v_to_float) {
6823
+ v_to_float (v_data, V32, DV);
6824
+ ggml_vec_mad_f32 (DV, VKQ32, V32, vs);
6825
+ } else {
6826
+ // V is F32
6827
+ ggml_vec_mad_f32 (DV, VKQ32, (const float *) v_data, vs);
6828
+ }
6825
6829
}
6826
6830
6827
6831
S = S*ms + vs; // scale and increment sum with partial sum
You can’t perform that action at this time.
0 commit comments