@@ -8436,74 +8436,33 @@ static enum ggml_status llama_graph_compute(
     return status;
 }

-// decode a batch of tokens by evaluating the transformer
-// in case of unsuccessful decoding (error or warning),
-// the kv_cache state will be returned to its original state
-// (for non-recurrent models) or cleaned (for recurrent models)
-//
-//   - lctx:  llama context
-//   - batch: batch to evaluate
-//
-//   return 0 on success
-//   return positive int on warning
-//   return negative int on error
-//
-static int llama_decode_impl(
-        llama_context & lctx,
-          llama_batch   inp_batch) {
-
-    lctx.is_encoding = false;
-
-    if (inp_batch.n_tokens == 0) {
-        LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
-        return -1;
-    }
-
-    // temporary allocate memory for the input batch if needed
-    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : lctx.kv_self.max_pos() + 1);
-
-    const llama_batch & batch = batch_allocr.batch;
-    const uint32_t n_tokens_all = batch.n_tokens;
-
+static int llama_prepare_sbatch(
+        llama_context     & lctx,
+        const llama_batch & batch,
+        uint32_t          & n_outputs) {
     const auto & model   = lctx.model;
-    const auto & vocab   = model.vocab;
     const auto & hparams = model.hparams;
     const auto & cparams = lctx.cparams;

-    GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
+    const uint32_t n_tokens_all = batch.n_tokens;
+    const int64_t  n_embd       = hparams.n_embd;
+
+    // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
+    const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;

+    GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
     if (batch.token) {
         for (uint32_t i = 0; i < n_tokens_all; ++i) {
-            if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
+            if (batch.token[i] < 0 || uint32_t(batch.token[i]) >= model.vocab.n_tokens()) {
                 LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
                 return -1;
             }
         }
     }
-
     GGML_ASSERT(n_tokens_all <= cparams.n_batch);
-
     GGML_ASSERT((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens");

-    if (lctx.t_compute_start_us == 0) {
-        lctx.t_compute_start_us = ggml_time_us();
-    }
     lctx.n_queued_tokens += n_tokens_all;
-
-    auto & kv_self = lctx.kv_self;
-    llama_kv_slot_restorer kv_slot_restorer(kv_self);
-
-    const int64_t n_embd  = hparams.n_embd;
-    const int64_t n_vocab = vocab.n_tokens();
-
-    uint32_t n_outputs      = 0;
-    uint32_t n_outputs_prev = 0;
-
-    const auto n_ubatch = cparams.n_ubatch;
-
-    // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
-    const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
-
     lctx.embd_seq.clear();

     // count outputs
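Note on the output counting that llama_prepare_sbatch now reports through its n_outputs out-parameter: unless every token is requested (logits_all) or pooled embeddings force all tokens to produce outputs, only tokens whose batch.logits flag is non-zero are counted. A minimal caller-side sketch, not part of this change, using the public llama_batch_init API; the tokens pointer, n_prompt, and the last-token-only policy are illustrative assumptions:

    #include "llama.h"

    // sketch: build a batch in which only the last prompt token requests an output,
    // so the n_outputs counted by llama_prepare_sbatch comes out as 1
    // (tokens and n_prompt are placeholders supplied by the caller)
    static llama_batch make_prompt_batch(const llama_token * tokens, int32_t n_prompt) {
        llama_batch batch = llama_batch_init(n_prompt, /*embd*/ 0, /*n_seq_max*/ 1);
        batch.n_tokens = n_prompt;
        for (int32_t i = 0; i < n_prompt; ++i) {
            batch.token[i]     = tokens[i];
            batch.pos[i]       = i;
            batch.n_seq_id[i]  = 1;
            batch.seq_id[i][0] = 0;
            batch.logits[i]    = i == n_prompt - 1; // only this token is counted as an output
        }
        return batch; // release with llama_batch_free() after llama_decode()
    }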
@@ -8519,7 +8478,7 @@ static int llama_decode_impl(
     }

     lctx.sbatch.from_batch(batch, n_embd,
-        /* simple_split */ !kv_self.recurrent,
+        /* simple_split */ !lctx.kv_self.recurrent,
         /* logits_all   */ n_outputs == n_tokens_all);

     // reserve output buffer
@@ -8528,70 +8487,148 @@ static int llama_decode_impl(
         return -2;
     };

-    while (lctx.sbatch.n_tokens > 0) {
-        llama_ubatch ubatch;
-        if (kv_self.recurrent) {
-            if (embd_pooled) {
-                // Pooled embeddings cannot be split across ubatches (yet)
-                ubatch = lctx.sbatch.split_seq(n_ubatch);
-            } else {
-                // recurrent model architectures are easier to implement
-                // with equal-length sequences
-                ubatch = lctx.sbatch.split_equal(n_ubatch);
-            }
+    return 0;
+}
+
+static int llama_prepare_ubatch(
+        llama_context          & lctx,
+        llama_kv_slot_restorer & kv_slot_restorer,
+        llama_ubatch           & ubatch,
+        const uint32_t           n_outputs,
+        const uint32_t           n_tokens_all) {
+    GGML_ASSERT(lctx.sbatch.n_tokens > 0);
+
+    auto & kv_self = lctx.kv_self;
+    const auto & cparams = lctx.cparams;
+    const auto & hparams = lctx.model.hparams;
+
+    // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
+    const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
+
+    if (lctx.kv_self.recurrent) {
+        if (embd_pooled) {
+            // Pooled embeddings cannot be split across ubatches (yet)
+            ubatch = lctx.sbatch.split_seq(cparams.n_ubatch);
         } else {
-            ubatch = lctx.sbatch.split_simple(n_ubatch);
+            // recurrent model architectures are easier to implement
+            // with equal-length sequences
+            ubatch = lctx.sbatch.split_equal(cparams.n_ubatch);
         }
-        const uint32_t n_tokens = ubatch.n_tokens;
+    } else {
+        ubatch = lctx.sbatch.split_simple(cparams.n_ubatch);
+    }

-        // count the outputs in this u_batch
-        {
-            int32_t n_outputs_new = 0;
+    // count the outputs in this u_batch
+    {
+        int32_t n_outputs_new = 0;

-            if (n_outputs == n_tokens_all) {
-                n_outputs_new = n_tokens;
-            } else {
-                GGML_ASSERT(ubatch.output);
-                for (uint32_t i = 0; i < n_tokens; i++) {
-                    n_outputs_new += (int32_t) (ubatch.output[i] != 0);
-                }
+        if (n_outputs == n_tokens_all) {
+            n_outputs_new = ubatch.n_tokens;
+        } else {
+            GGML_ASSERT(ubatch.output);
+            for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
+                n_outputs_new += int32_t(ubatch.output[i] != 0);
             }
+        }
+
+        // needs to happen before the graph is built
+        lctx.n_outputs = n_outputs_new;
+    }
+
+    // non-causal masks do not use the KV cache
+    if (hparams.causal_attn) {
+        llama_kv_cache_update(&lctx);

-            // needs to happen before the graph is built
-            lctx.n_outputs = n_outputs_new;
+        // if we have enough unused cells before the current head ->
+        //   better to start searching from the beginning of the cache, hoping to fill it
+        if (kv_self.head > kv_self.used + 2*ubatch.n_tokens) {
+            kv_self.head = 0;
         }

-        int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
-        ggml_threadpool_t threadpool = n_tokens == 1 ? lctx.threadpool : lctx.threadpool_batch;
+        const auto slot = llama_kv_cache_find_slot(kv_self, ubatch);
+        if (!slot) {
+            return 1;
+        }
+        kv_slot_restorer.save(slot);

-        GGML_ASSERT(n_threads > 0);
+        if (!kv_self.recurrent) {
+            // a heuristic, to avoid attending the full cache if it is not yet utilized
+            // after enough generations, the benefit from this heuristic disappears
+            // if we start defragmenting the cache, the benefit from this will be more important
+            const uint32_t pad = llama_kv_cache_get_padding(cparams);
+            kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(llama_kv_cache_cell_max(kv_self), pad)));
+            //kv_self.n = llama_kv_cache_cell_max(kv_self);
+        }
+    }

-        // non-causal masks do not use the KV cache
-        if (hparams.causal_attn) {
-            llama_kv_cache_update(&lctx);
+    return 0;
+}

-            // if we have enough unused cells before the current head ->
-            //   better to start searching from the beginning of the cache, hoping to fill it
-            if (kv_self.head > kv_self.used + 2*n_tokens) {
-                kv_self.head = 0;
-            }
+// decode a batch of tokens by evaluating the transformer
+// in case of unsuccessful decoding (error or warning),
+// the kv_cache state will be returned to its original state
+// (for non-recurrent models) or cleaned (for recurrent models)
+//
+//   - lctx:      llama context
+//   - inp_batch: batch to evaluate
+//
+//   return 0 on success
+//   return positive int on warning
+//   return negative int on error
+//
+static int llama_decode_impl(
+        llama_context & lctx,
+          llama_batch   inp_batch) {

-            const auto slot = llama_kv_cache_find_slot(kv_self, ubatch);
-            if (!slot) {
-                return 1;
-            }
-            kv_slot_restorer.save(slot);
+    lctx.is_encoding = false;

-            if (!kv_self.recurrent) {
-                // a heuristic, to avoid attending the full cache if it is not yet utilized
-                // after enough generations, the benefit from this heuristic disappears
-                // if we start defragmenting the cache, the benefit from this will be more important
-                const uint32_t pad = llama_kv_cache_get_padding(cparams);
-                kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(llama_kv_cache_cell_max(kv_self), pad)));
-                //kv_self.n = llama_kv_cache_cell_max(kv_self);
+    if (inp_batch.n_tokens == 0) {
+        LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
+        return -1;
+    }
+
+    // temporarily allocate memory for the input batch if needed
+    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : lctx.kv_self.max_pos() + 1);
+    const llama_batch & batch = batch_allocr.batch;
+
+    const auto & model   = lctx.model;
+    const auto & vocab   = model.vocab;
+    const auto & hparams = model.hparams;
+    const auto & cparams = lctx.cparams;
+
+    if (lctx.t_compute_start_us == 0) {
+        lctx.t_compute_start_us = ggml_time_us();
+    }
+    auto & kv_self = lctx.kv_self;
+    llama_kv_slot_restorer kv_slot_restorer(kv_self);
+
+    const int64_t n_embd  = hparams.n_embd;
+    const int64_t n_vocab = vocab.n_tokens();
+
+    uint32_t n_outputs      = 0;
+    uint32_t n_outputs_prev = 0;
+
+    {
+        const int ret = llama_prepare_sbatch(lctx, batch, n_outputs);
+        if (ret != 0) {
+            return ret;
+        }
+    }
+
+    while (lctx.sbatch.n_tokens > 0) {
+        llama_ubatch ubatch;
+        {
+            const int ret = llama_prepare_ubatch(lctx, kv_slot_restorer, ubatch, n_outputs, batch.n_tokens);
+            if (ret != 0) {
+                return ret;
             }
         }

+        const int n_threads = ubatch.n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
+        ggml_threadpool_t threadpool = ubatch.n_tokens == 1 ? lctx.threadpool : lctx.threadpool_batch;
+
+        GGML_ASSERT(n_threads > 0);
+
         //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);

         ggml_backend_sched_reset(lctx.sched.get());
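Note on the kv_self.n heuristic that moves into llama_prepare_ubatch above: GGML_PAD(x, n) rounds x up to a multiple of n, so the attended KV window grows in padded steps instead of jumping straight to the full cache size. A standalone sketch of the same arithmetic with made-up numbers; the real values come from llama_kv_cache_get_padding(cparams) and llama_kv_cache_cell_max(kv_self), and pad_up is a local stand-in for the GGML_PAD macro:

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>

    // same rounding as GGML_PAD(x, n) in ggml.h: round x up to a multiple of n (n a power of two)
    static uint32_t pad_up(uint32_t x, uint32_t n) { return (x + n - 1) & ~(n - 1); }

    int main() {
        const uint32_t kv_size  = 4096; // hypothetical kv_self.size
        const uint32_t pad      = 32;   // hypothetical padding from llama_kv_cache_get_padding()
        const uint32_t cell_max = 70;   // hypothetical llama_kv_cache_cell_max(): last used cell + 1

        // mirrors: kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(cell_max, pad)));
        const uint32_t kv_n = std::min(kv_size, std::max(pad, pad_up(cell_max, pad)));
        std::printf("attend to %u of %u KV cells\n", kv_n, kv_size); // prints 96, not 4096
        return 0;
    }

The padded window keeps the attended range at sizes the compute kernels handle well while still growing with actual cache usage; the commented-out line in the diff shows the unpadded alternative.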
@@ -8644,7 +8681,7 @@ static int llama_decode_impl(

         // update the kv ring buffer
         {
-            kv_self.head += n_tokens;
+            kv_self.head += ubatch.n_tokens;

             // Ensure kv cache head points to a valid index.
             if (kv_self.head >= kv_self.size) {
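Note on the return-value contract restated in the comment block above (0 = success, positive int = warning such as a failed KV-slot search, negative int = error): this is what callers of the public llama_decode wrapper can act on. A minimal caller sketch, assuming the wrapper forwards llama_decode_impl's return value unchanged; decode_checked is a hypothetical helper:

    #include <cstdio>
    #include "llama.h"

    // sketch: acting on llama_decode()'s return value from application code
    static bool decode_checked(llama_context * ctx, llama_batch batch) {
        const int ret = llama_decode(ctx, batch);
        if (ret < 0) {
            // hard error (invalid input, failed compute, ...); the KV cache was restored or cleared
            std::fprintf(stderr, "llama_decode failed with %d\n", ret);
            return false;
        }
        if (ret > 0) {
            // warning, e.g. no free KV cache slot was found for this batch
            std::fprintf(stderr, "llama_decode returned warning %d\n", ret);
            return false;
        }
        return true; // success: outputs for the requested tokens are ready
    }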