Commit 780d6fb

kv-cache : do not pass the full llama_context for kv graphs
ggml-ci
1 parent 54f2bd4 commit 780d6fb
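The refactor narrows the KV-cache graph builders to take only the pieces of state they actually read, a const llama_cparams & and the compute ggml_context *, instead of the whole llama_context. The self-contained toy below sketches that pattern with hypothetical stand-in types; nothing here is llama.cpp API, it only mirrors the call shape of the updated call sites (lctx.get_cparams(), lctx.get_ctx_compute()):

// Toy sketch of the pattern (hypothetical stand-in types, not llama.cpp API):
// the graph builder takes only the parameters and compute context it uses,
// and the caller unpacks them from the full context at the call site.
#include <cstdio>

struct cparams_t { float freq_base; };   // stand-in for llama_cparams
struct compute_t { int   n_nodes;   };   // stand-in for ggml_context

struct context_t {                       // stand-in for llama_context
    cparams_t cparams;
    compute_t compute;

    const cparams_t & get_cparams() const { return cparams; }
    compute_t * get_ctx_compute() { return &compute; }
};

struct kv_cache_t {
    // after the refactor: no dependency on context_t in the builder itself
    void build_graph_shift(const cparams_t & cparams, compute_t * ctx) const {
        std::printf("freq_base=%.1f n_nodes=%d\n", cparams.freq_base, ctx->n_nodes);
    }
};

int main() {
    context_t  lctx = {{10000.0f}, {0}};
    kv_cache_t kv;

    // mirrors the updated call sites in this diff:
    // build_graph_shift(lctx.get_cparams(), lctx.get_ctx_compute(), gf)
    kv.build_graph_shift(lctx.get_cparams(), lctx.get_ctx_compute());
    return 0;
}

After the change, build_rope_shift, build_graph_shift and build_graph_defrag no longer take a llama_context at all; only the update() call sites, which already receive lctx, unpack it at the boundary.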

File tree: 2 files changed (+29 -35 lines changed)
src/llama-kv-cache.cpp (+16 -24)

@@ -386,7 +386,7 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {

             auto * gf = lctx.graph_init();

-            auto res = build_graph_shift(lctx, gf);
+            auto res = build_graph_shift(lctx.get_cparams(), lctx.get_ctx_compute(), gf);

             ggml_backend_sched_alloc_graph(sched, gf);

@@ -414,7 +414,7 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {

             auto * gf = lctx.graph_init();

-            auto res = build_graph_defrag(lctx, gf);
+            auto res = build_graph_defrag(lctx.get_cparams(), lctx.get_ctx_compute(), gf);

             ggml_backend_sched_alloc_graph(sched, gf);

@@ -592,15 +592,13 @@ size_t llama_kv_cache_unified::size_v_bytes() const {
 }

 ggml_tensor * llama_kv_cache_unified::build_rope_shift(
-          llama_context & lctx,
-           ggml_context * ctx,
-            ggml_tensor * cur,
-            ggml_tensor * shift,
-            ggml_tensor * factors,
-                  float   freq_base,
-                  float   freq_scale) const {
-    const auto & cparams = lctx.get_cparams();
-
+    const llama_cparams & cparams,
+           ggml_context * ctx,
+            ggml_tensor * cur,
+            ggml_tensor * shift,
+            ggml_tensor * factors,
+                  float   freq_base,
+                  float   freq_scale) const {
     const auto & n_ctx_orig = cparams.n_ctx_orig_yarn;

     const auto & yarn_ext_factor = cparams.yarn_ext_factor;
@@ -662,14 +660,11 @@ void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) {
 }

 llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
-          llama_context & lctx,
-            ggml_cgraph * gf) const {
+    const llama_cparams & cparams,
+           ggml_context * ctx,
+            ggml_cgraph * gf) const {
     auto res = std::make_unique<llm_graph_result>();

-    auto * ctx = lctx.get_ctx_compute();
-
-    const auto & cparams = lctx.get_cparams();
-
     const auto & n_layer = hparams.n_layer;

     const auto & n_embd_head_k = hparams.n_embd_head_k;
@@ -704,7 +699,7 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
                 ggml_row_size(k_l[il]->type, n_embd_k_gqa),
                 0);

-        ggml_tensor * cur = build_rope_shift(lctx, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l);
+        ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l);

         ggml_build_forward_expand(gf, cur);
     }
@@ -715,16 +710,13 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
 }

 llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
-          llama_context & lctx,
-            ggml_cgraph * gf) const {
+    const llama_cparams & cparams,
+           ggml_context * ctx,
+            ggml_cgraph * gf) const {
     auto res = std::make_unique<llm_graph_result>();

-    auto * ctx = lctx.get_ctx_compute();
-
     const auto & ids = defrag_info.ids;

-    const auto & cparams = lctx.get_cparams();
-
 #if 0
     // CPU defrag
     //

src/llama-kv-cache.h (+13 -11)

@@ -233,21 +233,23 @@ class llama_kv_cache_unified : public llama_kv_cache {
     size_t size_v_bytes() const;

     ggml_tensor * build_rope_shift(
-              llama_context & lctx,
-               ggml_context * ctx,
-                ggml_tensor * cur,
-                ggml_tensor * shift,
-                ggml_tensor * factors,
-                      float   freq_base,
-                      float   freq_scale) const;
+        const llama_cparams & cparams,
+               ggml_context * ctx,
+                ggml_tensor * cur,
+                ggml_tensor * shift,
+                ggml_tensor * factors,
+                      float   freq_base,
+                      float   freq_scale) const;

     llm_graph_result_ptr build_graph_shift(
-              llama_context & lctx,
-                ggml_cgraph * gf) const;
+        const llama_cparams & cparams,
+               ggml_context * ctx,
+                ggml_cgraph * gf) const;

     llm_graph_result_ptr build_graph_defrag(
-              llama_context & lctx,
-                ggml_cgraph * gf) const;
+        const llama_cparams & cparams,
+               ggml_context * ctx,
+                ggml_cgraph * gf) const;

     void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
     void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;
