Commit 07f8474

JohannesGaessler authored and NeoZhangJianyu committed
llama: refactor llama_decode_impl (ggml-org#11381)
1 parent d050cba commit 07f8474

File tree

1 file changed: +139 -102 lines

src/llama.cpp (+139 -102)
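At a glance, the refactor splits the old monolithic llama_decode_impl into two helpers: llama_prepare_sbatch (per-call validation of the batch and counting of the requested outputs) and llama_prepare_ubatch (per-ubatch splitting plus KV-cache update and slot search). The sketch below condenses the new control flow from the diff that follows; setup details and the graph build/compute code are elided, so treat it as an outline rather than the verbatim source.

static int llama_decode_impl(llama_context & lctx, llama_batch inp_batch) {
    // ... input checks and other setup elided, see the diff below ...
    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : lctx.kv_self.max_pos() + 1);
    const llama_batch & batch = batch_allocr.batch;

    llama_kv_slot_restorer kv_slot_restorer(lctx.kv_self);
    uint32_t n_outputs = 0;

    // once per call: validate tokens, count outputs, reserve the output buffer
    {
        const int ret = llama_prepare_sbatch(lctx, batch, n_outputs);
        if (ret != 0) {
            return ret;
        }
    }

    while (lctx.sbatch.n_tokens > 0) {
        // once per ubatch: split the sbatch, update the KV cache, find and save a KV slot
        llama_ubatch ubatch;
        {
            const int ret = llama_prepare_ubatch(lctx, kv_slot_restorer, ubatch, n_outputs, batch.n_tokens);
            if (ret != 0) {
                return ret;
            }
        }

        // ... graph build and compute for this ubatch (unchanged by this commit) ...
    }

    return 0;
}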
@@ -8436,74 +8436,33 @@ static enum ggml_status llama_graph_compute(
     return status;
 }
 
-// decode a batch of tokens by evaluating the transformer
-// in case of unsuccessful decoding (error or warning),
-// the kv_cache state will be returned to its original state
-// (for non-recurrent models) or cleaned (for recurrent models)
-//
-//   - lctx:      llama context
-//   - batch:     batch to evaluate
-//
-// return 0 on success
-// return positive int on warning
-// return negative int on error
-//
-static int llama_decode_impl(
-        llama_context & lctx,
-        llama_batch     inp_batch) {
-
-    lctx.is_encoding = false;
-
-    if (inp_batch.n_tokens == 0) {
-        LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
-        return -1;
-    }
-
-    // temporary allocate memory for the input batch if needed
-    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : lctx.kv_self.max_pos() + 1);
-
-    const llama_batch & batch = batch_allocr.batch;
-    const uint32_t n_tokens_all = batch.n_tokens;
-
+static int llama_prepare_sbatch(
+        llama_context & lctx,
+        const llama_batch & batch,
+        uint32_t & n_outputs) {
     const auto & model   = lctx.model;
-    const auto & vocab   = model.vocab;
     const auto & hparams = model.hparams;
     const auto & cparams = lctx.cparams;
 
-    GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
+    const uint32_t n_tokens_all = batch.n_tokens;
+    const  int64_t n_embd       = hparams.n_embd;
+
+    // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
+    const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
 
+    GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
     if (batch.token) {
         for (uint32_t i = 0; i < n_tokens_all; ++i) {
-            if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
+            if (batch.token[i] < 0 || uint32_t(batch.token[i]) >= model.vocab.n_tokens()) {
                 LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
                 return -1;
             }
         }
     }
-
     GGML_ASSERT(n_tokens_all <= cparams.n_batch);
-
     GGML_ASSERT((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens");
 
-    if (lctx.t_compute_start_us == 0) {
-        lctx.t_compute_start_us = ggml_time_us();
-    }
     lctx.n_queued_tokens += n_tokens_all;
-
-    auto & kv_self = lctx.kv_self;
-    llama_kv_slot_restorer kv_slot_restorer(kv_self);
-
-    const int64_t n_embd  = hparams.n_embd;
-    const int64_t n_vocab = vocab.n_tokens();
-
-    uint32_t n_outputs      = 0;
-    uint32_t n_outputs_prev = 0;
-
-    const auto n_ubatch = cparams.n_ubatch;
-
-    // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
-    const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
-
     lctx.embd_seq.clear();
 
     // count outputs
@@ -8519,7 +8478,7 @@ static int llama_decode_impl(
     }
 
     lctx.sbatch.from_batch(batch, n_embd,
-        /* simple_split */ !kv_self.recurrent,
+        /* simple_split */ !lctx.kv_self.recurrent,
         /* logits_all   */ n_outputs == n_tokens_all);
 
     // reserve output buffer
@@ -8528,70 +8487,148 @@ static int llama_decode_impl(
         return -2;
     };
 
-    while (lctx.sbatch.n_tokens > 0) {
-        llama_ubatch ubatch;
-        if (kv_self.recurrent) {
-            if (embd_pooled) {
-                // Pooled embeddings cannot be split across ubatches (yet)
-                ubatch = lctx.sbatch.split_seq(n_ubatch);
-            } else {
-                // recurrent model architectures are easier to implement
-                // with equal-length sequences
-                ubatch = lctx.sbatch.split_equal(n_ubatch);
-            }
+    return 0;
+}
+
+static int llama_prepare_ubatch(
+        llama_context & lctx,
+        llama_kv_slot_restorer & kv_slot_restorer,
+        llama_ubatch & ubatch,
+        const uint32_t n_outputs,
+        const uint32_t n_tokens_all) {
+    GGML_ASSERT(lctx.sbatch.n_tokens > 0);
+
+    auto & kv_self = lctx.kv_self;
+    const auto & cparams = lctx.cparams;
+    const auto & hparams = lctx.model.hparams;
+
+    // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
+    const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
+
+    if (lctx.kv_self.recurrent) {
+        if (embd_pooled) {
+            // Pooled embeddings cannot be split across ubatches (yet)
+            ubatch = lctx.sbatch.split_seq(cparams.n_ubatch);
         } else {
-            ubatch = lctx.sbatch.split_simple(n_ubatch);
+            // recurrent model architectures are easier to implement
+            // with equal-length sequences
+            ubatch = lctx.sbatch.split_equal(cparams.n_ubatch);
         }
-        const uint32_t n_tokens = ubatch.n_tokens;
+    } else {
+        ubatch = lctx.sbatch.split_simple(cparams.n_ubatch);
+    }
 
-        // count the outputs in this u_batch
-        {
-            int32_t n_outputs_new = 0;
+    // count the outputs in this u_batch
+    {
+        int32_t n_outputs_new = 0;
 
-            if (n_outputs == n_tokens_all) {
-                n_outputs_new = n_tokens;
-            } else {
-                GGML_ASSERT(ubatch.output);
-                for (uint32_t i = 0; i < n_tokens; i++) {
-                    n_outputs_new += (int32_t) (ubatch.output[i] != 0);
-                }
+        if (n_outputs == n_tokens_all) {
+            n_outputs_new = ubatch.n_tokens;
+        } else {
+            GGML_ASSERT(ubatch.output);
+            for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
+                n_outputs_new += int32_t(ubatch.output[i] != 0);
             }
+        }
+
+        // needs to happen before the graph is built
+        lctx.n_outputs = n_outputs_new;
+    }
+
+    // non-causal masks do not use the KV cache
+    if (hparams.causal_attn) {
+        llama_kv_cache_update(&lctx);
 
-            // needs to happen before the graph is built
-            lctx.n_outputs = n_outputs_new;
+        // if we have enough unused cells before the current head ->
+        //   better to start searching from the beginning of the cache, hoping to fill it
+        if (kv_self.head > kv_self.used + 2*ubatch.n_tokens) {
+            kv_self.head = 0;
         }
 
-        int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
-        ggml_threadpool_t threadpool = n_tokens == 1 ? lctx.threadpool : lctx.threadpool_batch;
+        const auto slot = llama_kv_cache_find_slot(kv_self, ubatch);
+        if (!slot) {
+            return 1;
+        }
+        kv_slot_restorer.save(slot);
 
-        GGML_ASSERT(n_threads > 0);
+        if (!kv_self.recurrent) {
+            // a heuristic, to avoid attending the full cache if it is not yet utilized
+            // after enough generations, the benefit from this heuristic disappears
+            // if we start defragmenting the cache, the benefit from this will be more important
+            const uint32_t pad = llama_kv_cache_get_padding(cparams);
+            kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(llama_kv_cache_cell_max(kv_self), pad)));
+            //kv_self.n = llama_kv_cache_cell_max(kv_self);
+        }
+    }
 
-        // non-causal masks do not use the KV cache
-        if (hparams.causal_attn) {
-            llama_kv_cache_update(&lctx);
+    return 0;
+}
 
-            // if we have enough unused cells before the current head ->
-            //   better to start searching from the beginning of the cache, hoping to fill it
-            if (kv_self.head > kv_self.used + 2*n_tokens) {
-                kv_self.head = 0;
-            }
+// decode a batch of tokens by evaluating the transformer
+// in case of unsuccessful decoding (error or warning),
+// the kv_cache state will be returned to its original state
+// (for non-recurrent models) or cleaned (for recurrent models)
+//
+//   - lctx:      llama context
+//   - inp_batch: batch to evaluate
+//
+// return 0 on success
+// return positive int on warning
+// return negative int on error
+//
+static int llama_decode_impl(
+        llama_context & lctx,
+        llama_batch     inp_batch) {
 
-            const auto slot = llama_kv_cache_find_slot(kv_self, ubatch);
-            if (!slot) {
-                return 1;
-            }
-            kv_slot_restorer.save(slot);
+    lctx.is_encoding = false;
 
-            if (!kv_self.recurrent) {
-                // a heuristic, to avoid attending the full cache if it is not yet utilized
-                // after enough generations, the benefit from this heuristic disappears
-                // if we start defragmenting the cache, the benefit from this will be more important
-                const uint32_t pad = llama_kv_cache_get_padding(cparams);
-                kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(llama_kv_cache_cell_max(kv_self), pad)));
-                //kv_self.n = llama_kv_cache_cell_max(kv_self);
+    if (inp_batch.n_tokens == 0) {
+        LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
+        return -1;
+    }
+
+    // temporarily allocate memory for the input batch if needed
+    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : lctx.kv_self.max_pos() + 1);
+    const llama_batch & batch = batch_allocr.batch;
+
+    const auto & model   = lctx.model;
+    const auto & vocab   = model.vocab;
+    const auto & hparams = model.hparams;
+    const auto & cparams = lctx.cparams;
+
+    if (lctx.t_compute_start_us == 0) {
+        lctx.t_compute_start_us = ggml_time_us();
+    }
+    auto & kv_self = lctx.kv_self;
+    llama_kv_slot_restorer kv_slot_restorer(kv_self);
+
+    const int64_t n_embd  = hparams.n_embd;
+    const int64_t n_vocab = vocab.n_tokens();
+
+    uint32_t n_outputs      = 0;
+    uint32_t n_outputs_prev = 0;
+
+    {
+        const int ret = llama_prepare_sbatch(lctx, batch, n_outputs);
+        if (ret != 0) {
+            return ret;
+        }
+    }
+
+    while (lctx.sbatch.n_tokens > 0) {
+        llama_ubatch ubatch;
+        {
+            const int ret = llama_prepare_ubatch(lctx, kv_slot_restorer, ubatch, n_outputs, batch.n_tokens);
+            if (ret != 0) {
+                return ret;
             }
         }
 
+        const int n_threads = ubatch.n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
+        ggml_threadpool_t threadpool = ubatch.n_tokens == 1 ? lctx.threadpool : lctx.threadpool_batch;
+
+        GGML_ASSERT(n_threads > 0);
+
         //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
 
         ggml_backend_sched_reset(lctx.sched.get());
@@ -8644,7 +8681,7 @@ static int llama_decode_impl(
 
         // update the kv ring buffer
         {
-            kv_self.head += n_tokens;
+            kv_self.head += ubatch.n_tokens;
 
             // Ensure kv cache head points to a valid index.
             if (kv_self.head >= kv_self.size) {
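For reference, the refactor keeps the existing return-code convention spelled out in the restored doc comment (0 on success, positive int on warning, negative int on error), which callers see unchanged through the public llama_decode() wrapper. A minimal caller-side sketch, assuming `ctx` and `batch` have been created elsewhere:

#include "llama.h"

// Not part of this commit: a hypothetical helper illustrating how a caller can
// react to the three classes of return values.
static bool decode_or_report(llama_context * ctx, llama_batch batch) {
    const int32_t ret = llama_decode(ctx, batch);
    if (ret < 0) {
        // negative int: hard error (e.g. invalid token), decoding failed
        return false;
    }
    if (ret > 0) {
        // positive int: warning, e.g. no free KV-cache slot was found for the ubatch;
        // the KV cache state was restored/cleared, so the caller may retry with a smaller batch
        return false;
    }
    return true; // 0: success
}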

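The heuristic that limits attention to the used portion of the KV cache moves verbatim into llama_prepare_ubatch(). A small self-contained illustration of that arithmetic, using hypothetical numbers and a local round_up that stands in for GGML_PAD (assumed here to round its argument up to a multiple of pad):

#include <algorithm>
#include <cstdint>
#include <cstdio>

// stand-in for GGML_PAD: round x up to the next multiple of n (assumption, not the macro itself)
static uint32_t round_up(uint32_t x, uint32_t n) {
    return ((x + n - 1) / n) * n;
}

int main() {
    const uint32_t size     = 4096; // kv_self.size: total number of cache cells (hypothetical)
    const uint32_t pad      = 32;   // llama_kv_cache_get_padding(cparams) (hypothetical value)
    const uint32_t cell_max = 70;   // llama_kv_cache_cell_max(kv_self): highest used cell (hypothetical)

    // attend only to the padded prefix of the cache that is actually in use
    const uint32_t n = std::min(size, std::max(pad, round_up(cell_max, pad)));
    std::printf("kv_self.n = %u\n", n); // prints 96: 70 rounded up to a multiple of 32
    return 0;
}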