Skip to content

Commit 58d0fc2

Browse files
committed
server : add test that exercises embeddings with FA enabled
ggml-ci
1 parent 7cb9ae0 commit 58d0fc2

File tree

3 files changed

+40
-0
lines changed

3 files changed

+40
-0
lines changed

examples/server/tests/unit/test_embedding.py

+20
Original file line number | Diff line number | Diff line change
@@ -49,6 +49,26 @@ def test_embedding_multiple():
4949
assert len(d['embedding']) > 1
5050

5151

52+
def test_embedding_multiple_with_fa():
    """Embed four long prompts in one request against the FA-enabled preset.

    The server comes from ``ServerPreset.bert_bge_small_with_fa()`` with
    pooling set to ``'last'``. The test checks that the request succeeds and
    that every returned item carries an embedding with more than one element.
    """
    fa_server = ServerPreset.bert_bge_small_with_fa()
    fa_server.pooling = 'last'
    fa_server.start()

    # one of these should trigger the FA branch (i.e. context size % 256 == 0)
    prompts = [
        f"{letter} " * repeats
        for letter, repeats in zip("abcd", (253, 254, 255, 256))
    ]
    payload = {"input": prompts}
    response = fa_server.make_request("POST", "/v1/embeddings", data=payload)

    assert response.status_code == 200

    results = response.body['data']
    assert len(results) == 4
    for item in results:
        assert 'embedding' in item
        assert len(item['embedding']) > 1
70+
71+
5272
@pytest.mark.parametrize(
5373
"input,is_multi_prompt",
5474
[

examples/server/tests/utils.py

+15
Original file line number | Diff line number | Diff line change
@@ -323,6 +323,21 @@ def bert_bge_small() -> ServerProcess:
323323
server.server_embeddings = True
324324
return server
325325

326+
@staticmethod
def bert_bge_small_with_fa() -> ServerProcess:
    """Return a ``ServerProcess`` preset for bert-bge-small with ``fa`` on.

    Only attributes are set here; the returned process is not started.
    """
    # Attribute name -> value, applied in this order to a fresh ServerProcess.
    settings = {
        "model_hf_repo": "ggml-org/models",
        "model_hf_file": "bert-bge-small/ggml-model-f16.gguf",
        "model_alias": "bert-bge-small",
        "n_ctx": 1024,
        "n_batch": 300,
        "n_ubatch": 300,
        "n_slots": 2,
        "fa": True,
        "seed": 42,
        "server_embeddings": True,
    }
    preset = ServerProcess()
    for attr_name, attr_value in settings.items():
        setattr(preset, attr_name, attr_value)
    return preset
340+
326341
@staticmethod
327342
def tinyllama_infill() -> ServerProcess:
328343
server = ServerProcess()

ggml/src/ggml-metal/ggml-metal.m

+5
Original file line number | Diff line number | Diff line change
@@ -1345,6 +1345,11 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
13451345
case GGML_OP_ARANGE:
13461346
return true;
13471347
case GGML_OP_FLASH_ATTN_EXT:
1348+
if (op->src[0]->ne[0] == 32) {
1349+
// head size == 32 (e.g. bert-bge-small)
1350+
// TODO: not sure if it is worth adding kernels for this size
1351+
return false;
1352+
}
13481353
if (op->src[1]->type != op->src[2]->type) {
13491354
return false;
13501355
}

0 commit comments

Comments
 (0)