Skip to content

Commit 8a4bad5

Browse files
authored
llama: use sliding window for phi3 (#8627)
* use sliding window for phi3 * fix typo, "data_swa" -> "data" * [conver_hf_to_gguf.py] add phi3 sliding window
1 parent 68504f0 commit 8a4bad5

File tree

2 files changed

+29
-9
lines changed

2 files changed

+29
-9
lines changed

convert_hf_to_gguf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2084,6 +2084,7 @@ def set_gguf_parameters(self):
20842084
self.gguf_writer.add_rope_dimension_count(rope_dims)
20852085
self.gguf_writer.add_rope_freq_base(self.find_hparam(["rope_theta"]))
20862086
self.gguf_writer.add_file_type(self.ftype)
2087+
self.gguf_writer.add_sliding_window(self.find_hparam(["sliding_window"]))
20872088

20882089
# write rope scaling for long context (128k) model
20892090
rope_scaling = self.find_hparam(['rope_scaling'], True)

src/llama.cpp

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4889,6 +4889,7 @@ static void llm_load_hparams(
48894889
} break;
48904890
case LLM_ARCH_PHI3:
48914891
{
4892+
ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
48924893
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
48934894

48944895
switch (hparams.n_layer) {
@@ -10748,7 +10749,7 @@ struct llm_build_context {
1074810749
struct ggml_tensor * inp_pos = build_inp_pos();
1074910750

1075010751
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
10751-
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
10752+
struct ggml_tensor * KQ_mask_swa = build_inp_KQ_mask_swa();
1075210753

1075310754
for (int il = 0; il < n_layer; ++il) {
1075410755
auto residual = inpL;
@@ -10806,7 +10807,7 @@ struct llm_build_context {
1080610807

1080710808
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
1080810809
model.layers[il].wo, model.layers[il].bo,
10809-
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f, cb, il);
10810+
Kcur, Vcur, Qcur, KQ_mask_swa, n_tokens, kv_head, n_kv, 1.0f, cb, il);
1081010811
}
1081110812

1081210813
if (il == n_layer - 1) {
@@ -14013,18 +14014,23 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
1401314014
"causal attention is not supported by this model"
1401414015
);
1401514016

14016-
if (lctx.inp_KQ_mask) {
14017+
if (lctx.inp_KQ_mask || lctx.inp_KQ_mask_swa) {
1401714018
// NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache.
1401814019
if (cparams.causal_attn && !lctx.is_encoding) {
1401914020
const int64_t n_kv = kv_self.n;
1402014021
const int64_t n_tokens = batch.n_tokens;
1402114022

14022-
GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
1402314023

14024-
float * data = (float *) lctx.inp_KQ_mask->data;
14024+
float * data = nullptr;
1402514025
float * data_swa = nullptr;
1402614026

14027+
if (lctx.inp_KQ_mask) {
14028+
GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
14029+
data = (float *) lctx.inp_KQ_mask->data;
14030+
}
14031+
1402714032
if (lctx.inp_KQ_mask_swa) {
14033+
GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask_swa->buffer));
1402814034
data_swa = (float *) lctx.inp_KQ_mask_swa->data;
1402914035
}
1403014036

@@ -14047,7 +14053,10 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
1404714053
f = 0.0f;
1404814054
}
1404914055
}
14050-
data[h*(n_kv*n_tokens) + j*n_kv + i] = f;
14056+
14057+
if (data) {
14058+
data[h*(n_kv*n_tokens) + j*n_kv + i] = f;
14059+
}
1405114060

1405214061
// may need to cut off old tokens for sliding window
1405314062
if (data_swa) {
@@ -14059,9 +14068,19 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
1405914068
}
1406014069
}
1406114070

14062-
for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
14063-
for (int j = 0; j < n_kv; ++j) {
14064-
data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
14071+
if (data) {
14072+
for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
14073+
for (int j = 0; j < n_kv; ++j) {
14074+
data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
14075+
}
14076+
}
14077+
}
14078+
14079+
if (data_swa) {
14080+
for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
14081+
for (int j = 0; j < n_kv; ++j) {
14082+
data_swa[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
14083+
}
1406514084
}
1406614085
}
1406714086
}

0 commit comments

Comments
 (0)