Skip to content

Commit 726c0fa

Browse files
ikawrakow and Kawrakow
authored
Slightly faster imatrix (ggml-org#5050)
* imatrix: speedup by avoiding unnecessary allocations and copies * imatrix: add --no-ppl option to skip PPL calculations altogether --------- Co-authored-by: Iwan Kawrakow <[email protected]>
1 parent 942c010 commit 726c0fa

File tree

1 file changed

+49
-25
lines changed

1 file changed

+49
-25
lines changed

examples/imatrix/imatrix.cpp

Lines changed: 49 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -248,7 +248,7 @@ static void process_logits(
248248
}
249249
}
250250

251-
static bool compute_imatrix(llama_context * ctx, const gpt_params & params) {
251+
static bool compute_imatrix(llama_context * ctx, const gpt_params & params, bool compute_ppl) {
252252

253253
const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));
254254
const int n_ctx = llama_n_ctx(ctx);
@@ -269,10 +269,12 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params) {
269269
}
270270

271271
std::vector<float> logit_history;
272-
logit_history.resize(tokens.size());
273-
274272
std::vector<float> prob_history;
275-
prob_history.resize(tokens.size());
273+
274+
if (compute_ppl) {
275+
logit_history.resize(tokens.size());
276+
prob_history.resize(tokens.size());
277+
}
276278

277279
const int n_chunk_max = tokens.size() / n_ctx;
278280

@@ -288,12 +290,17 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params) {
288290

289291
std::vector<std::thread> workers(std::thread::hardware_concurrency() - 1);
290292

293+
const int num_batches = (n_ctx + n_batch - 1) / n_batch;
294+
295+
std::vector<float> logits;
296+
if (compute_ppl && num_batches > 1) {
297+
logits.reserve((size_t)n_ctx * n_vocab);
298+
}
299+
291300
for (int i = 0; i < n_chunk; ++i) {
292301
const int start = i * n_ctx;
293302
const int end = start + n_ctx;
294303

295-
const int num_batches = (n_ctx + n_batch - 1) / n_batch;
296-
297304
std::vector<float> logits;
298305

299306
const auto t_start = std::chrono::high_resolution_clock::now();
@@ -321,8 +328,10 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params) {
321328
// restore the original token in case it was set to BOS
322329
tokens[batch_start] = token_org;
323330

324-
const auto * batch_logits = llama_get_logits(ctx);
325-
logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
331+
if (compute_ppl && num_batches > 1) {
332+
const auto * batch_logits = llama_get_logits(ctx);
333+
logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
334+
}
326335
}
327336

328337
const auto t_end = std::chrono::high_resolution_clock::now();
@@ -338,25 +347,32 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params) {
338347
fprintf(stderr, "%.2f minutes\n", total_seconds / 60.0);
339348
}
340349

341-
const int first = n_ctx/2;
342-
process_logits(n_vocab, logits.data() + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
343-
workers, nll, nll2, logit_history.data() + start + first, prob_history.data() + start + first);
344-
count += n_ctx - first - 1;
350+
if (compute_ppl) {
351+
const int first = n_ctx/2;
352+
const auto all_logits = num_batches > 1 ? logits.data() : llama_get_logits(ctx);
353+
process_logits(n_vocab, all_logits + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
354+
workers, nll, nll2, logit_history.data() + start + first, prob_history.data() + start + first);
355+
count += n_ctx - first - 1;
356+
357+
printf("[%d]%.4lf,", i + 1, std::exp(nll / count));
358+
fflush(stdout);
345359

346-
printf("[%d]%.4lf,", i + 1, std::exp(nll / count));
347-
fflush(stdout);
360+
logits.clear();
361+
}
348362
}
349363
printf("\n");
350364

351-
nll2 /= count;
352-
nll /= count;
353-
const double ppl = exp(nll);
354-
nll2 -= nll * nll;
355-
if (nll2 > 0) {
356-
nll2 = sqrt(nll2/(count-1));
357-
printf("Final estimate: PPL = %.4lf +/- %.5lf\n", ppl, nll2*ppl);
358-
} else {
359-
printf("Unexpected negative standard deviation of log(prob)\n");
365+
if (compute_ppl) {
366+
nll2 /= count;
367+
nll /= count;
368+
const double ppl = exp(nll);
369+
nll2 -= nll * nll;
370+
if (nll2 > 0) {
371+
nll2 = sqrt(nll2/(count-1));
372+
printf("Final estimate: PPL = %.4lf +/- %.5lf\n", ppl, nll2*ppl);
373+
} else {
374+
printf("Unexpected negative standard deviation of log(prob)\n");
375+
}
360376
}
361377

362378
return true;
@@ -365,6 +381,7 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params) {
365381
int main(int argc, char ** argv) {
366382

367383
StatParams sparams;
384+
bool compute_ppl = true;
368385
std::vector<char*> args;
369386
args.push_back(argv[0]);
370387
int iarg = 1;
@@ -381,12 +398,19 @@ int main(int argc, char ** argv) {
381398
}
382399
else if (arg == "--verbosity") {
383400
sparams.verbosity = std::stoi(argv[++iarg]);
401+
} else if (arg == "--no-ppl") {
402+
compute_ppl = false;
384403
} else {
385404
args.push_back(argv[iarg]);
386405
}
387406
}
388407
if (iarg < argc) {
389-
args.push_back(argv[iarg]);
408+
std::string arg{argv[iarg]};
409+
if (arg == "--no-ppl") {
410+
compute_ppl = false;
411+
} else {
412+
args.push_back(argv[iarg]);
413+
}
390414
}
391415

392416
gpt_params params;
@@ -448,7 +472,7 @@ int main(int argc, char ** argv) {
448472
fprintf(stderr, "%s\n", get_system_info(params).c_str());
449473
}
450474

451-
bool OK = compute_imatrix(ctx, params);
475+
bool OK = compute_imatrix(ctx, params, compute_ppl);
452476
if (!OK) {
453477
return 1;
454478
}

0 commit comments

Comments (0)