Commit fa664af

vlovich authored and mglambda committed
llama : expose llama_model_n_head_kv in the API (ggml-org#11997)
It's useful to be able to have this from the library layer as it's a key parameter of the model (e.g. to figure out how much KV cache memory is needed).
1 parent 4cce5e9 commit fa664af

File tree

2 files changed: +5 additions, −0 deletions

include/llama.h

Lines changed: 1 addition & 0 deletions
@@ -477,6 +477,7 @@ extern "C" {
     LLAMA_API int32_t llama_model_n_embd     (const struct llama_model * model);
     LLAMA_API int32_t llama_model_n_layer    (const struct llama_model * model);
     LLAMA_API int32_t llama_model_n_head     (const struct llama_model * model);
+    LLAMA_API int32_t llama_model_n_head_kv  (const struct llama_model * model);

     // Get the model's RoPE frequency scaling factor
     LLAMA_API float llama_model_rope_freq_scale_train(const struct llama_model * model);

src/llama-model.cpp

Lines changed: 4 additions & 0 deletions
@@ -3838,6 +3838,10 @@ int32_t llama_model_n_head(const struct llama_model * model) {
     return model->hparams.n_head();
 }

+int32_t llama_model_n_head_kv(const struct llama_model * model) {
+    return model->hparams.n_head_kv();
+}
+
 // deprecated
 int32_t llama_n_ctx_train(const struct llama_model * model) {
     return llama_model_n_ctx_train(model);

0 commit comments
