@@ -439,6 +439,21 @@ int main(int argc, char ** argv) {
    LOG_TEE("sampling: \n%s\n", llama_sampling_print(sparams).c_str());
    LOG_TEE("sampling order: \n%s\n", llama_sampling_order_print(sparams).c_str());
    LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
+
+    // group-attention state
+    // number of grouped KV tokens so far (used only if params.grp_attn_n > 1)
+    int ga_i = 0;
+
+    const int ga_n = params.grp_attn_n;
+    const int ga_w = params.grp_attn_w;
+
+    if (ga_n != 1) {
+        GGML_ASSERT(ga_n > 0                    && "grp_attn_n must be positive");                     // NOLINT
+        GGML_ASSERT(ga_w % ga_n == 0            && "grp_attn_w must be a multiple of grp_attn_n");     // NOLINT
+      //GGML_ASSERT(n_ctx_train % ga_w == 0     && "n_ctx_train must be a multiple of grp_attn_w");    // NOLINT
+      //GGML_ASSERT(n_ctx >= n_ctx_train * ga_n && "n_ctx must be at least n_ctx_train * grp_attn_n"); // NOLINT
+        LOG_TEE("self-extend: n_ctx_train = %d, grp_attn_n = %d, grp_attn_w = %d\n", n_ctx_train, ga_n, ga_w);
+    }
    LOG_TEE("\n\n");

    if (params.interactive) {
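Aside: this first hunk only stores and validates the group-attention parameters; the cache manipulation itself comes in the next hunk. As a rough standalone sketch of what the two parameters mean (the helper grouped_span and all numbers below are hypothetical, not part of the patch): ga_n is the group factor and ga_w is the width of the window that gets grouped in one step; ga_n is normally chosen so that the grouped positions of a full n_ctx cache stay near or below n_ctx_train.

#include <cassert>
#include <cstdio>

// grouped_span() estimates how many distinct positions n_tokens cached tokens
// occupy once every group of ga_n tokens shares a position (it ignores the
// most recent, not-yet-grouped window, so it is only an approximation).
static int grouped_span(int n_tokens, int ga_n) {
    return (n_tokens + ga_n - 1) / ga_n;   // ceiling division
}

int main() {
    const int n_ctx_train = 4096;          // hypothetical: what the model was trained on
    const int n_ctx       = 16384;         // hypothetical: the context we want to run with
    const int ga_n        = 4;
    const int ga_w        = 2048;

    assert(ga_n > 0);                      // same checks as the patch
    assert(ga_w % ga_n == 0);

    // with ga_n = 4, the bulk of a 16384-token cache is squeezed into ~4096
    // positions, which stays within the trained context
    printf("grouped span = %d (n_ctx_train = %d)\n", grouped_span(n_ctx, ga_n), n_ctx_train);
}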
@@ -500,37 +515,61 @@ int main(int argc, char ** argv) {
            fflush(stdout);
        }

-        // infinite text generation via context swapping
-        // if we run out of context:
-        // - take the n_keep first tokens from the original prompt (via n_past)
-        // - take half of the last (n_ctx - n_keep) tokens and recompute the logits in batches
-        if (n_past + (int) embd.size() + std::max<int>(0, guidance_offset) > n_ctx) {
-            if (params.n_predict == -2) {
-                LOG_TEE("\n\n%s: context full and n_predict == -%d => stopping\n", __func__, params.n_predict);
-                break;
-            }
+        if (ga_n == 1) {
+            // infinite text generation via context shifting
+            // if we run out of context:
+            // - take the n_keep first tokens from the original prompt (via n_past)
+            // - take half of the last (n_ctx - n_keep) tokens and recompute the logits in batches
+            if (n_past + (int) embd.size() + std::max<int>(0, guidance_offset) > n_ctx) {
+                if (params.n_predict == -2) {
+                    LOG_TEE("\n\n%s: context full and n_predict == -%d => stopping\n", __func__, params.n_predict);
+                    break;
+                }

-            const int n_left    = n_past - params.n_keep - 1;
-            const int n_discard = n_left/2;
+                const int n_left    = n_past - params.n_keep - 1;
+                const int n_discard = n_left/2;

-            LOG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
-                n_past, n_left, n_ctx, params.n_keep, n_discard);
+                LOG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
+                    n_past, n_left, n_ctx, params.n_keep, n_discard);

-            llama_kv_cache_seq_rm   (ctx, 0, params.n_keep + 1            , params.n_keep + n_discard + 1);
-            llama_kv_cache_seq_shift(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard);
+                llama_kv_cache_seq_rm   (ctx, 0, params.n_keep + 1            , params.n_keep + n_discard + 1);
+                llama_kv_cache_seq_shift(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard);

-            n_past -= n_discard;
+                n_past -= n_discard;

-            if (ctx_guidance) {
-                n_past_guidance -= n_discard;
+                if (ctx_guidance) {
+                    n_past_guidance -= n_discard;
+                }
+
+                LOG("after swap: n_past = %d, n_past_guidance = %d\n", n_past, n_past_guidance);
+
+                LOG("embd: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd).c_str());
+
+                LOG("clear session path\n");
+                path_session.clear();
            }
+        } else {
+            // context extension via Self-Extend
+            while (n_past >= ga_i + ga_w) {
+                const int ib = (ga_n*ga_i)/ga_w;
+                const int bd = (ga_w/ga_n)*(ga_n - 1);
+                const int dd = (ga_w/ga_n) - ib*bd - ga_w;

-            LOG("after swap: n_past = %d, n_past_guidance = %d\n", n_past, n_past_guidance);
+                LOG("\n");
+                LOG("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", ga_i, n_past, ib*bd, ga_i + ib*bd, n_past + ib*bd);
+                LOG("div:   [%6d, %6d] / %6d -> [%6d, %6d]\n", ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n, (ga_i + ib*bd)/ga_n, (ga_i + ib*bd + ga_w)/ga_n);
+                LOG("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", ga_i + ib*bd + ga_w, n_past + ib*bd, dd, ga_i + ib*bd + ga_w + dd, n_past + ib*bd + dd);

-            LOG("embd: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd).c_str());
+                llama_kv_cache_seq_shift(ctx, 0, ga_i,                n_past,              ib*bd);
+                llama_kv_cache_seq_div  (ctx, 0, ga_i + ib*bd,        ga_i + ib*bd + ga_w, ga_n);
+                llama_kv_cache_seq_shift(ctx, 0, ga_i + ib*bd + ga_w, n_past + ib*bd,      dd);

-            LOG("clear session path\n");
-            path_session.clear();
+                n_past -= bd;
+
+                ga_i += ga_w/ga_n;
+
+                LOG("\nn_past_old = %d, n_past = %d, ga_i = %d\n\n", n_past + bd, n_past, ga_i);
+            }
        }

        // try to reuse a matching prefix from the loaded session instead of re-eval (via n_past)
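For the ga_n == 1 path, a worked example of the shifting arithmetic may help. The snippet below replays it on a plain vector of positions standing in for the KV cache (seq_rm/seq_shift are local stand-ins, not the llama.cpp API, and the numbers are made up): with n_ctx = 512 and n_keep = 32, a full cache keeps the first 33 positions, drops the next 239, and slides the remaining 239 down so generation resumes at position 272 instead of 511.

#include <algorithm>
#include <cstdio>
#include <vector>

// Local stand-ins for llama_kv_cache_seq_rm / llama_kv_cache_seq_shift,
// applied to a plain list of positions instead of the real KV cache.
static void seq_rm(std::vector<int> & pos, int p0, int p1) {
    pos.erase(std::remove_if(pos.begin(), pos.end(),
                             [=](int p) { return p >= p0 && p < p1; }),
              pos.end());
}
static void seq_shift(std::vector<int> & pos, int p0, int p1, int delta) {
    for (auto & p : pos) if (p >= p0 && p < p1) p += delta;
}

int main() {
    const int n_keep = 32;                 // made-up value for illustration
    int n_past = 511;                      // cache is full (n_ctx = 512)

    std::vector<int> pos(n_past);
    for (int i = 0; i < n_past; ++i) pos[i] = i;

    // same arithmetic as the ga_n == 1 branch of the patch
    const int n_left    = n_past - params_n_keep_placeholder(n_keep) - 1;  // see note below
    const int n_discard = n_left/2;                                        // 239

    seq_rm   (pos, n_keep + 1,             n_keep + n_discard + 1);   // drop positions [33, 272)
    seq_shift(pos, n_keep + 1 + n_discard, n_past, -n_discard);       // [272, 511) -> [33, 272)

    n_past -= n_discard;

    printf("tokens kept = %zu, new n_past = %d\n", pos.size(), n_past);   // 272 and 272
}

// note: params_n_keep_placeholder() is just n_keep; written out:
// const int n_left = n_past - n_keep - 1;   // 478

Half of the evictable region is discarded each time, so this shift runs at most once every (n_ctx - n_keep)/2 generated tokens rather than on every step.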
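For the Self-Extend path, the shift/div/shift triplet is easiest to see on concrete numbers. The sketch below replays the same index arithmetic on a plain vector of positions (seq_shift/seq_div are local stand-ins for the llama_kv_cache_seq_* calls, and ga_n = 4, ga_w = 8 are toy values): every group of ga_n consecutive tokens ends up sharing one position, so the positional range the model sees grows roughly ga_n times slower than the number of cached tokens.

#include <cstdio>
#include <vector>

// Local stand-ins for llama_kv_cache_seq_shift / llama_kv_cache_seq_div,
// applied to a plain list of positions instead of the real KV cache.
static void seq_shift(std::vector<int> & pos, int p0, int p1, int delta) {
    for (auto & p : pos) if (p >= p0 && p < p1) p += delta;
}
static void seq_div(std::vector<int> & pos, int p0, int p1, int d) {
    for (auto & p : pos) if (p >= p0 && p < p1) p /= d;
}

int main() {
    const int ga_n = 4, ga_w = 8;          // toy values; real runs use larger ones
    int n_past = 0, ga_i = 0;

    std::vector<int> pos;                  // one simulated position per cached token
    for (int t = 0; t < 16; ++t) {
        pos.push_back(n_past++);           // decode one token at position n_past

        while (n_past >= ga_i + ga_w) {    // identical arithmetic to the patch
            const int ib = (ga_n*ga_i)/ga_w;
            const int bd = (ga_w/ga_n)*(ga_n - 1);
            const int dd = (ga_w/ga_n) - ib*bd - ga_w;

            seq_shift(pos, ga_i,                n_past,              ib*bd);
            seq_div  (pos, ga_i + ib*bd,        ga_i + ib*bd + ga_w, ga_n);
            seq_shift(pos, ga_i + ib*bd + ga_w, n_past + ib*bd,      dd);

            n_past -= bd;
            ga_i   += ga_w/ga_n;
        }
    }

    for (int p : pos) printf("%d ", p);    // prints: 0 0 0 0 1 1 1 1 2 2 2 2 3 3 3 3
    printf("\nn_past = %d\n", n_past);     // prints: n_past = 4
}

The +ib*bd shift is needed because tokens decoded after earlier groupings were assigned already-compressed positions; shifting them up by ib*bd restores their original coordinates, the next ga_w-wide window is then divided by ga_n, and the final shift by dd pulls whatever remains down so it continues right after the newly compressed window.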