Commit 7cb9ae0

graph : cast KV to F16 when the KV cache is not used

ggml-ci

1 parent 3e6d1e4

2 files changed: +10 -1 lines

examples/server_embd.py (+1, -1)

@@ -15,7 +15,7 @@ async def main():
     model_url = "http://127.0.0.1:6900"
     responses: list[requests.Response] = await asyncio.gather(*[requests_post_async(
         url= f"{model_url}/embedding",
-        json= {"content": str(0)*1024}
+        json= {"content": "a "*1022}
     ) for i in range(n)])
 
     for response in responses:

src/llama-graph.cpp (+9, -0)

@@ -1215,6 +1215,15 @@ ggml_tensor * llm_graph_context::build_attn_mha(
             v = ggml_transpose(ctx0, v);
         }
 
+        // this can happen when KV cache is not used (e.g. an embedding model with non-causal attn)
+        if (k->type == GGML_TYPE_F32) {
+            k = ggml_cast(ctx0, k, GGML_TYPE_F16);
+        }
+
+        if (v->type == GGML_TYPE_F32) {
+            v = ggml_cast(ctx0, v, GGML_TYPE_F16);
+        }
+
         cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
                                   hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);

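For context: when inference runs without a KV cache (e.g. an embedding model with non-causal attention), K and V arrive in the graph as F32, and the flash-attention path expects non-F32 (typically F16) K/V, hence the guarded casts above. Below is a minimal standalone sketch, not part of the commit, of the same guard-and-cast pattern using the public ggml API. The tensor shapes, memory size, and main() wrapper are illustrative assumptions, and graph execution is omitted since the compute entry point varies across ggml versions.

// sketch: cast F32 K/V to F16 before flash attention, mirroring build_attn_mha
#include "ggml.h"

int main() {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,   // made-up size, ample for this demo
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx0 = ggml_init(params);

    // K/V produced directly by the graph (no KV cache) are F32;
    // shapes are arbitrary illustration values (head_dim = 128, n_kv = 32)
    struct ggml_tensor * k = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 128, 32);
    struct ggml_tensor * v = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 128, 32);

    // same guard as the commit: cast only when the type is F32, so K/V
    // already stored as F16 (or quantized) by a cache pass through untouched
    if (k->type == GGML_TYPE_F32) {
        k = ggml_cast(ctx0, k, GGML_TYPE_F16);
    }
    if (v->type == GGML_TYPE_F32) {
        v = ggml_cast(ctx0, v, GGML_TYPE_F16);
    }

    // record the cast ops in a compute graph (execution omitted here)
    struct ggml_cgraph * gf = ggml_new_graph(ctx0);
    ggml_build_forward_expand(gf, k);
    ggml_build_forward_expand(gf, v);

    ggml_free(ctx0);
    return 0;
}

In build_attn_mha itself, ctx0 and the K/V tensors already exist, so only the two if-blocks are new; the type check keeps the extra cast nodes out of the graph whenever the KV cache already holds K/V in a flash-attention-compatible type.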