Skip to content

Commit e0dbec0

Browse files
authored
llama : refactor llama_context, llama_kv_cache, llm_build_context (ggml-org#12181)
* llama : refactor llama_context, llama_kv_cache, llm_build_context ggml-ci * graph : don't mutate the KV cache during defrag ggml-ci * context : reduce virtuals + remove test function ggml-ci * context : move interface implementation to source file + factory ggml-ci * graph : move KV cache build functions to llama_context impl ggml-ci * graph : remove model reference from build_pooling ggml-ci * graph : remove llama_model reference ggml-ci * kv_cache : provide rope factors ggml-ci * graph : rework inputs to use only unique_ptr, remove attn input abstraction ggml-ci * context : remove llama_context_i abstraction ggml-ci * context : clean-up ggml-ci * graph : clean-up ggml-ci * llama : remove redundant keywords (struct, enum) ggml-ci * model : adapt gemma3 ggml-ci * graph : restore same attention ops as on master ggml-ci * llama : remove TODO + fix indent ggml-ci
1 parent 2048b59 commit e0dbec0

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

46 files changed

+13785
-12072
lines changed

common/common.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -955,8 +955,8 @@ struct common_init_result common_init_from_params(common_params & params) {
955955
return iparams;
956956
}
957957

958-
if (params.ctx_shift && !llama_kv_cache_can_shift(lctx)) {
959-
LOG_WRN("%s: KV cache shifting is not supported for this model, disabling KV cache shifting\n", __func__);
958+
if (params.ctx_shift && !llama_kv_self_can_shift(lctx)) {
959+
LOG_WRN("%s: KV cache shifting is not supported for this context, disabling KV cache shifting\n", __func__);
960960
params.ctx_shift = false;
961961
}
962962

@@ -1060,7 +1060,7 @@ struct common_init_result common_init_from_params(common_params & params) {
10601060
if (llama_model_has_decoder(model)) {
10611061
llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch)));
10621062
}
1063-
llama_kv_cache_clear(lctx);
1063+
llama_kv_self_clear(lctx);
10641064
llama_synchronize(lctx);
10651065
llama_perf_context_reset(lctx);
10661066
}

common/speculative.cpp

+4-4
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ llama_tokens common_speculative_gen_draft(
173173
result.reserve(params.n_draft);
174174

175175
if (reuse_n == 0) {
176-
llama_kv_cache_clear(ctx);
176+
llama_kv_self_clear(ctx);
177177

178178
prompt.clear();
179179
} else {
@@ -192,14 +192,14 @@ llama_tokens common_speculative_gen_draft(
192192
}
193193

194194
if (reuse_i > 0) {
195-
llama_kv_cache_seq_rm (ctx, 0, 0, reuse_i);
196-
llama_kv_cache_seq_add(ctx, 0, reuse_i, -1, -reuse_i);
195+
llama_kv_self_seq_rm (ctx, 0, 0, reuse_i);
196+
llama_kv_self_seq_add(ctx, 0, reuse_i, -1, -reuse_i);
197197

198198
prompt.erase(prompt.begin(), prompt.begin() + reuse_i);
199199
}
200200

201201
if (reuse_n < (int) prompt.size()) {
202-
llama_kv_cache_seq_rm (ctx, 0, reuse_n, -1);
202+
llama_kv_self_seq_rm (ctx, 0, reuse_n, -1);
203203

204204
prompt.erase(prompt.begin() + reuse_n, prompt.end());
205205
}

examples/batched-bench/batched-bench.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ int main(int argc, char ** argv) {
132132

133133
const auto t_pp_start = ggml_time_us();
134134

135-
llama_kv_cache_clear(ctx);
135+
llama_kv_self_clear(ctx);
136136

137137
if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
138138
LOG_ERR("%s: llama_decode() failed\n", __func__);
@@ -141,7 +141,7 @@ int main(int argc, char ** argv) {
141141

142142
if (is_pp_shared) {
143143
for (int32_t i = 1; i < pl; ++i) {
144-
llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
144+
llama_kv_self_seq_cp(ctx, 0, i, -1, -1);
145145
}
146146
}
147147

examples/batched.swift/Sources/main.swift

+1-1
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ if llama_decode(context, batch) != 0 {
116116
}
117117

118118
for i in 1 ..< n_parallel {
119-
llama_kv_cache_seq_cp(context, 0, Int32(i), 0, batch.n_tokens)
119+
llama_kv_self_seq_cp(context, 0, Int32(i), 0, batch.n_tokens)
120120
}
121121

122122
if n_parallel > 1 {

examples/cvector-generator/cvector-generator.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -342,7 +342,7 @@ static bool cb_eval(struct ggml_tensor * t, bool ask, void * user_data) {
342342
}
343343

344344
static bool get_hidden_layers(llama_context * ctx, std::vector<llama_token> & tokens) {
345-
llama_kv_cache_clear(ctx);
345+
llama_kv_self_clear(ctx);
346346
if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size()))) {
347347
fprintf(stderr, "%s : failed to eval\n", __func__);
348348
return false;

examples/embedding/embedding.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
3838
const struct llama_model * model = llama_get_model(ctx);
3939

4040
// clear previous kv_cache values (irrelevant for embeddings)
41-
llama_kv_cache_clear(ctx);
41+
llama_kv_self_clear(ctx);
4242

4343
// run model
4444
LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq);

examples/gritlm/gritlm.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
4545
}
4646

4747
// clear previous kv_cache values (irrelevant for embeddings)
48-
llama_kv_cache_clear(ctx);
48+
llama_kv_self_clear(ctx);
4949
llama_set_embeddings(ctx, true);
5050
llama_set_causal_attn(ctx, false);
5151

@@ -102,7 +102,7 @@ static std::string generate(llama_context * ctx, llama_sampler * smpl, const std
102102

103103
llama_token eos_token = llama_vocab_eos(vocab);
104104

105-
llama_kv_cache_clear(ctx);
105+
llama_kv_self_clear(ctx);
106106
llama_set_embeddings(ctx, false);
107107
llama_set_causal_attn(ctx, true);
108108

examples/imatrix/imatrix.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -495,7 +495,7 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params) {
495495
const auto t_start = std::chrono::high_resolution_clock::now();
496496

497497
// clear the KV cache
498-
llama_kv_cache_clear(ctx);
498+
llama_kv_self_clear(ctx);
499499

500500
llama_batch batch = llama_batch_init(n_batch, 0, 1);
501501

examples/infill/infill.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -332,8 +332,8 @@ int main(int argc, char ** argv) {
332332
LOG_DBG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
333333
n_past, n_left, n_ctx, params.n_keep, n_discard);
334334

335-
llama_kv_cache_seq_rm (ctx, 0, params.n_keep + 1 , params.n_keep + n_discard + 1);
336-
llama_kv_cache_seq_add(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard);
335+
llama_kv_self_seq_rm (ctx, 0, params.n_keep + 1 , params.n_keep + n_discard + 1);
336+
llama_kv_self_seq_add(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard);
337337

338338
n_past -= n_discard;
339339

examples/llama-bench/llama-bench.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -1578,7 +1578,7 @@ int main(int argc, char ** argv) {
15781578

15791579
test t(inst, lmodel, ctx);
15801580

1581-
llama_kv_cache_clear(ctx);
1581+
llama_kv_self_clear(ctx);
15821582

15831583
// cool off before the test
15841584
if (params.delay) {
@@ -1618,7 +1618,7 @@ int main(int argc, char ** argv) {
16181618
}
16191619

16201620
for (int i = 0; i < params.reps; i++) {
1621-
llama_kv_cache_clear(ctx);
1621+
llama_kv_self_clear(ctx);
16221622

16231623
uint64_t t_start = get_time_ns();
16241624

examples/llama.android/llama/src/main/cpp/llama-android.cpp

+4-4
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
194194
}
195195

196196
batch->logits[batch->n_tokens - 1] = true;
197-
llama_kv_cache_clear(context);
197+
llama_kv_self_clear(context);
198198

199199
const auto t_pp_start = ggml_time_us();
200200
if (llama_decode(context, *batch) != 0) {
@@ -206,7 +206,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
206206

207207
LOGi("Benchmark text generation (tg)");
208208

209-
llama_kv_cache_clear(context);
209+
llama_kv_self_clear(context);
210210
const auto t_tg_start = ggml_time_us();
211211
for (i = 0; i < tg; i++) {
212212

@@ -223,7 +223,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
223223

224224
const auto t_tg_end = ggml_time_us();
225225

226-
llama_kv_cache_clear(context);
226+
llama_kv_self_clear(context);
227227

228228
const auto t_pp = double(t_pp_end - t_pp_start) / 1000000.0;
229229
const auto t_tg = double(t_tg_end - t_tg_start) / 1000000.0;
@@ -448,5 +448,5 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop(
448448
extern "C"
449449
JNIEXPORT void JNICALL
450450
Java_android_llama_cpp_LLamaAndroid_kv_1cache_1clear(JNIEnv *, jobject, jlong context) {
451-
llama_kv_cache_clear(reinterpret_cast<llama_context *>(context));
451+
llama_kv_self_clear(reinterpret_cast<llama_context *>(context));
452452
}

examples/llama.swiftui/llama.cpp.swift/LibLlama.swift

+4-4
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ actor LlamaContext {
210210
}
211211
batch.logits[Int(batch.n_tokens) - 1] = 1 // true
212212

213-
llama_kv_cache_clear(context)
213+
llama_kv_self_clear(context)
214214

215215
let t_pp_start = DispatchTime.now().uptimeNanoseconds / 1000;
216216

@@ -223,7 +223,7 @@ actor LlamaContext {
223223

224224
// bench text generation
225225

226-
llama_kv_cache_clear(context)
226+
llama_kv_self_clear(context)
227227

228228
let t_tg_start = DispatchTime.now().uptimeNanoseconds / 1000;
229229

@@ -242,7 +242,7 @@ actor LlamaContext {
242242

243243
let t_tg_end = DispatchTime.now().uptimeNanoseconds / 1000;
244244

245-
llama_kv_cache_clear(context)
245+
llama_kv_self_clear(context)
246246

247247
let t_pp = Double(t_pp_end - t_pp_start) / 1000000.0
248248
let t_tg = Double(t_tg_end - t_tg_start) / 1000000.0
@@ -292,7 +292,7 @@ actor LlamaContext {
292292
func clear() {
293293
tokens_list.removeAll()
294294
temporary_invalid_cchars.removeAll()
295-
llama_kv_cache_clear(context)
295+
llama_kv_self_clear(context)
296296
}
297297

298298
private func tokenize(text: String, add_bos: Bool) -> [llama_token] {

examples/llava/gemma3-cli.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,7 @@ int main(int argc, char ** argv) {
309309
}
310310
if (line == "/clear") {
311311
ctx.n_past = 0;
312-
llama_kv_cache_seq_rm(ctx.lctx, 0, 1, -1); // keep BOS
312+
llama_kv_self_seq_rm(ctx.lctx, 0, 1, -1); // keep BOS
313313
LOG("Chat history cleared\n\n");
314314
continue;
315315
}

examples/lookahead/lookahead.cpp

+6-6
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ int main(int argc, char ** argv) {
9696
llama_decode(ctx, llama_batch_get_one(&inp.back(), 1));
9797

9898
for (int s = 1; s < W + G + 1; ++s) {
99-
llama_kv_cache_seq_cp(ctx, 0, s, -1, -1);
99+
llama_kv_self_seq_cp(ctx, 0, s, -1, -1);
100100
}
101101

102102
const auto t_enc_end = ggml_time_us();
@@ -438,17 +438,17 @@ int main(int argc, char ** argv) {
438438

439439
// KV cache management
440440
// if no verification token matched, we simply remove all cells from this batch -> no fragmentation
441-
llama_kv_cache_seq_rm(ctx, -1, n_past, -1);
441+
llama_kv_self_seq_rm(ctx, -1, n_past, -1);
442442

443443
if (seq_id_best != 0) {
444444
// if a verification token matched, we keep the best sequence and remove the rest
445445
// this leads to some KV cache fragmentation
446-
llama_kv_cache_seq_keep(ctx, seq_id_best);
447-
llama_kv_cache_seq_cp (ctx, seq_id_best, 0, -1, -1);
448-
llama_kv_cache_seq_rm (ctx, seq_id_best, -1, -1);
446+
llama_kv_self_seq_keep(ctx, seq_id_best);
447+
llama_kv_self_seq_cp (ctx, seq_id_best, 0, -1, -1);
448+
llama_kv_self_seq_rm (ctx, seq_id_best, -1, -1);
449449

450450
for (int s = 1; s < W + G + 1; ++s) {
451-
llama_kv_cache_seq_cp(ctx, 0, s, -1, -1);
451+
llama_kv_self_seq_cp(ctx, 0, s, -1, -1);
452452
}
453453
}
454454
}

examples/lookup/lookup.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ int main(int argc, char ** argv){
192192

193193
// KV cache management
194194
// clean the cache of draft tokens that weren't accepted
195-
llama_kv_cache_seq_rm(ctx, 0, n_past, -1);
195+
llama_kv_self_seq_rm(ctx, 0, n_past, -1);
196196

197197
common_batch_clear(batch_tgt);
198198
common_batch_add(batch_tgt, draft[0], n_past, { 0 }, true);

examples/main/main.cpp

+6-6
Original file line numberDiff line numberDiff line change
@@ -354,7 +354,7 @@ int main(int argc, char ** argv) {
354354
}
355355

356356
// remove any "future" tokens that we might have inherited from the previous session
357-
llama_kv_cache_seq_rm(ctx, -1, n_matching_session_tokens, -1);
357+
llama_kv_self_seq_rm(ctx, -1, n_matching_session_tokens, -1);
358358
}
359359

360360
LOG_DBG("recalculate the cached logits (check): embd_inp.size() %zu, n_matching_session_tokens %zu, embd_inp.size() %zu, session_tokens.size() %zu\n",
@@ -602,8 +602,8 @@ int main(int argc, char ** argv) {
602602
LOG_DBG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
603603
n_past, n_left, n_ctx, params.n_keep, n_discard);
604604

605-
llama_kv_cache_seq_rm (ctx, 0, params.n_keep , params.n_keep + n_discard);
606-
llama_kv_cache_seq_add(ctx, 0, params.n_keep + n_discard, n_past, -n_discard);
605+
llama_kv_self_seq_rm (ctx, 0, params.n_keep , params.n_keep + n_discard);
606+
llama_kv_self_seq_add(ctx, 0, params.n_keep + n_discard, n_past, -n_discard);
607607

608608
n_past -= n_discard;
609609

@@ -626,9 +626,9 @@ int main(int argc, char ** argv) {
626626
LOG_DBG("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n, (ga_i + ib*bd)/ga_n, (ga_i + ib*bd + ga_w)/ga_n);
627627
LOG_DBG("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", ga_i + ib*bd + ga_w, n_past + ib*bd, dd, ga_i + ib*bd + ga_w + dd, n_past + ib*bd + dd);
628628

629-
llama_kv_cache_seq_add(ctx, 0, ga_i, n_past, ib*bd);
630-
llama_kv_cache_seq_div(ctx, 0, ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n);
631-
llama_kv_cache_seq_add(ctx, 0, ga_i + ib*bd + ga_w, n_past + ib*bd, dd);
629+
llama_kv_self_seq_add(ctx, 0, ga_i, n_past, ib*bd);
630+
llama_kv_self_seq_div(ctx, 0, ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n);
631+
llama_kv_self_seq_add(ctx, 0, ga_i + ib*bd + ga_w, n_past + ib*bd, dd);
632632

633633
n_past -= bd;
634634

examples/parallel/parallel.cpp

+5-5
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ int main(int argc, char ** argv) {
202202

203203
// assign the system KV cache to all parallel sequences
204204
for (int32_t i = 1; i <= n_clients; ++i) {
205-
llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
205+
llama_kv_self_seq_cp(ctx, 0, i, -1, -1);
206206
}
207207

208208
LOG_INF("\n");
@@ -234,9 +234,9 @@ int main(int argc, char ** argv) {
234234
if (batch.n_tokens == 0) {
235235
// all sequences have ended - clear the entire KV cache
236236
for (int i = 1; i <= n_clients; ++i) {
237-
llama_kv_cache_seq_rm(ctx, i, -1, -1);
237+
llama_kv_self_seq_rm(ctx, i, -1, -1);
238238
// but keep the system prompt
239-
llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
239+
llama_kv_self_seq_cp(ctx, 0, i, -1, -1);
240240
}
241241

242242
LOG_INF("%s: clearing the KV cache\n", __func__);
@@ -372,8 +372,8 @@ int main(int argc, char ** argv) {
372372
}
373373

374374
// delete only the generated part of the sequence, i.e. keep the system prompt in the cache
375-
llama_kv_cache_seq_rm(ctx, client.id + 1, -1, -1);
376-
llama_kv_cache_seq_cp(ctx, 0, client.id + 1, -1, -1);
375+
llama_kv_self_seq_rm(ctx, client.id + 1, -1, -1);
376+
llama_kv_self_seq_cp(ctx, 0, client.id + 1, -1, -1);
377377

378378
const auto t_main_end = ggml_time_us();
379379

examples/passkey/passkey.cpp

+14-14
Original file line numberDiff line numberDiff line change
@@ -133,11 +133,11 @@ int main(int argc, char ** argv) {
133133
const int ib = i/n_batch - 1;
134134
const int bd = n_batch_grp*(n_grp - 1);
135135

136-
llama_kv_cache_seq_add (ctx, 0, n_past - n_batch, n_past, ib*bd);
137-
llama_kv_cache_seq_div (ctx, 0, n_past - n_batch + ib*bd, n_past + ib*bd, n_grp);
138-
llama_kv_cache_update (ctx);
136+
llama_kv_self_seq_add (ctx, 0, n_past - n_batch, n_past, ib*bd);
137+
llama_kv_self_seq_div (ctx, 0, n_past - n_batch + ib*bd, n_past + ib*bd, n_grp);
138+
llama_kv_self_update (ctx);
139139

140-
n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1;
140+
n_past = llama_kv_self_seq_pos_max(ctx, 0) + 1;
141141
}
142142

143143
common_batch_clear(batch);
@@ -167,12 +167,12 @@ int main(int argc, char ** argv) {
167167

168168
LOG_INF("%s: shifting KV cache with %d\n", __func__, n_discard);
169169

170-
llama_kv_cache_seq_rm (ctx, 0, n_keep , n_keep + n_discard);
171-
llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard);
172-
//llama_kv_cache_defrag (ctx);
173-
llama_kv_cache_update (ctx);
170+
llama_kv_self_seq_rm (ctx, 0, n_keep , n_keep + n_discard);
171+
llama_kv_self_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard);
172+
//llama_kv_self_defrag (ctx);
173+
llama_kv_self_update (ctx);
174174

175-
n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1;
175+
n_past = llama_kv_self_seq_pos_max(ctx, 0) + 1;
176176

177177
common_batch_clear(batch);
178178

@@ -198,12 +198,12 @@ int main(int argc, char ** argv) {
198198
if (n_discard > 0) {
199199
LOG_INF("%s: shifting KV cache with %d to free space for the answer\n", __func__, n_discard);
200200

201-
llama_kv_cache_seq_rm (ctx, 0, n_keep , n_keep + n_discard);
202-
llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard);
203-
//llama_kv_cache_defrag (ctx);
204-
llama_kv_cache_update (ctx);
201+
llama_kv_self_seq_rm (ctx, 0, n_keep , n_keep + n_discard);
202+
llama_kv_self_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard);
203+
//llama_kv_self_defrag (ctx);
204+
llama_kv_self_update (ctx);
205205

206-
n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1;
206+
n_past = llama_kv_self_seq_pos_max(ctx, 0) + 1;
207207
}
208208
}
209209

0 commit comments

Comments
 (0)