@@ -386,7 +386,7 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
 
             auto * gf = lctx.graph_init();
 
-            auto res = build_graph_shift(lctx, gf);
+            auto res = build_graph_shift(lctx.get_cparams(), lctx.get_ctx_compute(), gf);
 
             ggml_backend_sched_alloc_graph(sched, gf);
 
@@ -414,7 +414,7 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
 
             auto * gf = lctx.graph_init();
 
-            auto res = build_graph_defrag(lctx, gf);
+            auto res = build_graph_defrag(lctx.get_cparams(), lctx.get_ctx_compute(), gf);
 
             ggml_backend_sched_alloc_graph(sched, gf);
 
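
Taken together, the two hunks above show the call sites in `update()` unpacking the `llama_context` themselves. For reference, a sketch of the builder declarations these calls now target, inferred from the definition hunks below (the enclosing class body and the `llm_graph_result_ptr` alias are assumed from the surrounding llama.cpp sources, which this diff does not show):

```cpp
// Sketch only: declarations implied by the definition changes below.
ggml_tensor * build_rope_shift(
        const llama_cparams & cparams,
               ggml_context * ctx,
                ggml_tensor * cur,
                ggml_tensor * shift,
                ggml_tensor * factors,
                      float   freq_base,
                      float   freq_scale) const;

llm_graph_result_ptr build_graph_shift(
        const llama_cparams & cparams,
               ggml_context * ctx,
                ggml_cgraph * gf) const;

llm_graph_result_ptr build_graph_defrag(
        const llama_cparams & cparams,
               ggml_context * ctx,
                ggml_cgraph * gf) const;
```
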
@@ -592,15 +592,13 @@ size_t llama_kv_cache_unified::size_v_bytes() const {
 }
 
 ggml_tensor * llama_kv_cache_unified::build_rope_shift(
-        llama_context & lctx,
-         ggml_context * ctx,
-          ggml_tensor * cur,
-          ggml_tensor * shift,
-          ggml_tensor * factors,
-                float   freq_base,
-                float   freq_scale) const {
-    const auto & cparams = lctx.get_cparams();
-
+        const llama_cparams & cparams,
+               ggml_context * ctx,
+                ggml_tensor * cur,
+                ggml_tensor * shift,
+                ggml_tensor * factors,
+                      float   freq_base,
+                      float   freq_scale) const {
     const auto & n_ctx_orig = cparams.n_ctx_orig_yarn;
 
     const auto & yarn_ext_factor = cparams.yarn_ext_factor;
@@ -662,14 +660,11 @@ void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) {
 }
 
 llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
-        llama_context & lctx,
-          ggml_cgraph * gf) const {
+        const llama_cparams & cparams,
+               ggml_context * ctx,
+                ggml_cgraph * gf) const {
     auto res = std::make_unique<llm_graph_result>();
 
-    auto * ctx = lctx.get_ctx_compute();
-
-    const auto & cparams = lctx.get_cparams();
-
     const auto & n_layer = hparams.n_layer;
 
     const auto & n_embd_head_k = hparams.n_embd_head_k;
@@ -704,7 +699,7 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
                 ggml_row_size(k_l[il]->type, n_embd_k_gqa),
                 0);
 
-        ggml_tensor * cur = build_rope_shift(lctx, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l);
+        ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l);
 
         ggml_build_forward_expand(gf, cur);
     }
@@ -715,16 +710,13 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
 }
 
 llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
-        llama_context & lctx,
-          ggml_cgraph * gf) const {
+        const llama_cparams & cparams,
+               ggml_context * ctx,
+                ggml_cgraph * gf) const {
     auto res = std::make_unique<llm_graph_result>();
 
-    auto * ctx = lctx.get_ctx_compute();
-
     const auto & ids = defrag_info.ids;
 
-    const auto & cparams = lctx.get_cparams();
-
 #if 0
     // CPU defrag
     //
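
The net effect of these hunks: the KV-cache graph builders no longer take a `llama_context &`, only the `llama_cparams` and the compute `ggml_context` they actually read. A minimal sketch of the resulting calling pattern, using names visible in the hunks above (`kv` as the `llama_kv_cache_unified` instance is an illustrative assumption; `lctx`, `sched`, and the accessors come from the diff itself):

```cpp
// Illustrative caller, modeled on the update() hunks above.
auto * gf = lctx.graph_init();

// Before: build_graph_shift(lctx, gf) reached back into lctx internally.
// After: the caller passes the two dependencies explicitly, so the
// builder no longer depends on llama_context at all.
auto res = kv.build_graph_shift(lctx.get_cparams(), lctx.get_ctx_compute(), gf);

ggml_backend_sched_alloc_graph(sched, gf);
```

Presumably the motivation is the usual one for this kind of refactor: narrowing the parameter from the whole context to the two values that are read makes each builder's inputs explicit and decouples `llama_kv_cache_unified` from `llama_context`.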