@@ -248,7 +248,7 @@ static void process_logits(
248
248
}
249
249
}
250
250
251
- static bool compute_imatrix (llama_context * ctx, const gpt_params & params) {
251
+ static bool compute_imatrix (llama_context * ctx, const gpt_params & params, bool compute_ppl ) {
252
252
253
253
const bool add_bos = llama_should_add_bos_token (llama_get_model (ctx));
254
254
const int n_ctx = llama_n_ctx (ctx);
@@ -269,10 +269,12 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params) {
269
269
}
270
270
271
271
std::vector<float > logit_history;
272
- logit_history.resize (tokens.size ());
273
-
274
272
std::vector<float > prob_history;
275
- prob_history.resize (tokens.size ());
273
+
274
+ if (compute_ppl) {
275
+ logit_history.resize (tokens.size ());
276
+ prob_history.resize (tokens.size ());
277
+ }
276
278
277
279
const int n_chunk_max = tokens.size () / n_ctx;
278
280
@@ -288,12 +290,17 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params) {
288
290
289
291
std::vector<std::thread> workers (std::thread::hardware_concurrency () - 1 );
290
292
293
+ const int num_batches = (n_ctx + n_batch - 1 ) / n_batch;
294
+
295
+ std::vector<float > logits;
296
+ if (compute_ppl && num_batches > 1 ) {
297
+ logits.reserve ((size_t )n_ctx * n_vocab);
298
+ }
299
+
291
300
for (int i = 0 ; i < n_chunk; ++i) {
292
301
const int start = i * n_ctx;
293
302
const int end = start + n_ctx;
294
303
295
- const int num_batches = (n_ctx + n_batch - 1 ) / n_batch;
296
-
297
304
std::vector<float > logits;
298
305
299
306
const auto t_start = std::chrono::high_resolution_clock::now ();
@@ -321,8 +328,10 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params) {
321
328
// restore the original token in case it was set to BOS
322
329
tokens[batch_start] = token_org;
323
330
324
- const auto * batch_logits = llama_get_logits (ctx);
325
- logits.insert (logits.end (), batch_logits, batch_logits + batch_size * n_vocab);
331
+ if (compute_ppl && num_batches > 1 ) {
332
+ const auto * batch_logits = llama_get_logits (ctx);
333
+ logits.insert (logits.end (), batch_logits, batch_logits + batch_size * n_vocab);
334
+ }
326
335
}
327
336
328
337
const auto t_end = std::chrono::high_resolution_clock::now ();
@@ -338,25 +347,32 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params) {
338
347
fprintf (stderr, " %.2f minutes\n " , total_seconds / 60.0 );
339
348
}
340
349
341
- const int first = n_ctx/2 ;
342
- process_logits (n_vocab, logits.data () + first*n_vocab, tokens.data () + start + first, n_ctx - 1 - first,
343
- workers, nll, nll2, logit_history.data () + start + first, prob_history.data () + start + first);
344
- count += n_ctx - first - 1 ;
350
+ if (compute_ppl) {
351
+ const int first = n_ctx/2 ;
352
+ const auto all_logits = num_batches > 1 ? logits.data () : llama_get_logits (ctx);
353
+ process_logits (n_vocab, all_logits + first*n_vocab, tokens.data () + start + first, n_ctx - 1 - first,
354
+ workers, nll, nll2, logit_history.data () + start + first, prob_history.data () + start + first);
355
+ count += n_ctx - first - 1 ;
356
+
357
+ printf (" [%d]%.4lf," , i + 1 , std::exp (nll / count));
358
+ fflush (stdout);
345
359
346
- printf ( " [%d]%.4lf, " , i + 1 , std::exp (nll / count) );
347
- fflush (stdout);
360
+ logits. clear ( );
361
+ }
348
362
}
349
363
printf (" \n " );
350
364
351
- nll2 /= count;
352
- nll /= count;
353
- const double ppl = exp (nll);
354
- nll2 -= nll * nll;
355
- if (nll2 > 0 ) {
356
- nll2 = sqrt (nll2/(count-1 ));
357
- printf (" Final estimate: PPL = %.4lf +/- %.5lf\n " , ppl, nll2*ppl);
358
- } else {
359
- printf (" Unexpected negative standard deviation of log(prob)\n " );
365
+ if (compute_ppl) {
366
+ nll2 /= count;
367
+ nll /= count;
368
+ const double ppl = exp (nll);
369
+ nll2 -= nll * nll;
370
+ if (nll2 > 0 ) {
371
+ nll2 = sqrt (nll2/(count-1 ));
372
+ printf (" Final estimate: PPL = %.4lf +/- %.5lf\n " , ppl, nll2*ppl);
373
+ } else {
374
+ printf (" Unexpected negative standard deviation of log(prob)\n " );
375
+ }
360
376
}
361
377
362
378
return true ;
@@ -365,6 +381,7 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params) {
365
381
int main (int argc, char ** argv) {
366
382
367
383
StatParams sparams;
384
+ bool compute_ppl = true ;
368
385
std::vector<char *> args;
369
386
args.push_back (argv[0 ]);
370
387
int iarg = 1 ;
@@ -381,12 +398,19 @@ int main(int argc, char ** argv) {
381
398
}
382
399
else if (arg == " --verbosity" ) {
383
400
sparams.verbosity = std::stoi (argv[++iarg]);
401
+ } else if (arg == " --no-ppl" ) {
402
+ compute_ppl = false ;
384
403
} else {
385
404
args.push_back (argv[iarg]);
386
405
}
387
406
}
388
407
if (iarg < argc) {
389
- args.push_back (argv[iarg]);
408
+ std::string arg{argv[iarg]};
409
+ if (arg == " --no-ppl" ) {
410
+ compute_ppl = false ;
411
+ } else {
412
+ args.push_back (argv[iarg]);
413
+ }
390
414
}
391
415
392
416
gpt_params params;
@@ -448,7 +472,7 @@ int main(int argc, char ** argv) {
448
472
fprintf (stderr, " %s\n " , get_system_info (params).c_str ());
449
473
}
450
474
451
- bool OK = compute_imatrix (ctx, params);
475
+ bool OK = compute_imatrix (ctx, params, compute_ppl );
452
476
if (!OK) {
453
477
return 1 ;
454
478
}
0 commit comments