@@ -41,13 +41,11 @@ static constexpr auto kUseSDPAWithKVCache = "use_sdpa_with_kv_cache";
 Runner::Runner(
     const std::string& model_path,
     const std::string& tokenizer_path,
-    const float temperature,
     std::optional<const std::string> data_path)
     // NOTE: we observed ~2x loading performance increase on iPhone 15
     // and a ~5% improvement on Galaxy S22 by switching to
     // FileDataLoader instead of MmapDataLoader + UseMlockIgnoreErrors.
-    : temperature_(temperature),
-      tokenizer_path_(tokenizer_path),
+    : tokenizer_path_(tokenizer_path),
       metadata_({
           {kEnableDynamicShape, false},
           {kMaxSeqLen, 128},
@@ -68,6 +66,17 @@ Runner::Runner(
       tokenizer_path.c_str());
 }
 
+[[deprecated(
+    "This constructor is deprecated. Use the constructor without temperature parameter instead.")]]
+Runner::Runner(
+    const std::string& model_path,
+    const std::string& tokenizer_path,
+    const float temperature,
+    std::optional<const std::string> data_path)
+    : Runner(model_path, tokenizer_path, data_path) {
+  temperature_ = temperature;
+}
+
 bool Runner::is_loaded() const {
   return module_->is_loaded() && tokenizer_ && text_decoder_runner_ &&
       text_prefiller_ && text_token_generator_;
@@ -133,11 +142,9 @@ Error Runner::load() {
       ET_LOG(Info, "eos_id = %" PRId64, value);
     }
   }
+  // @lint-ignore CLANGTIDY facebook-hte-Deprecated
   text_decoder_runner_ = std::make_unique<llm::TextDecoderRunner>(
-      module_.get(),
-      metadata_.at(kUseKVCache),
-      metadata_.at(kVocabSize),
-      temperature_);
+      module_.get(), metadata_.at(kUseKVCache));
   text_prefiller_ = std::make_unique<llm::TextPrefiller>(
       text_decoder_runner_.get(),
       metadata_.at(kUseKVCache),
@@ -164,11 +171,9 @@ Error Runner::load() {
 
 Error Runner::generate(
     const std::string& prompt,
-    int32_t seq_len,
+    const ::executorch::extension::llm::GenerationConfig& config,
     std::function<void(const std::string&)> token_callback,
-    std::function<void(const llm::Stats&)> stats_callback,
-    bool echo,
-    bool warmup) {
+    std::function<void(const llm::Stats&)> stats_callback) {
   // Prepare the inputs.
   // Use ones-initialized inputs.
   ET_CHECK_MSG(!prompt.empty(), "Prompt cannot be null");
@@ -178,19 +183,19 @@ Error Runner::generate(
     stats_.model_load_end_ms = llm::time_in_ms();
   }
 
-  if (warmup) {
+  if (config.warming) {
     ET_LOG(Info, "Doing a warmup run...");
   }
 
   RUNNER_ET_LOG(
-      warmup,
+      config.warming,
       "RSS after loading model: %f MiB (0 if unsupported)",
       llm::get_rss_bytes() / 1024.0 / 1024.0);
 
   // Wrap the token_callback with print function
   std::function<void(const std::string&)> wrapped_callback =
-      [token_callback, warmup](const std::string& piece) {
-        if (!warmup) {
+      [token_callback, config](const std::string& piece) {
+        if (!config.warming) {
           llm::safe_printf(piece.c_str());
           fflush(stdout);
         }
@@ -204,11 +209,6 @@ Error Runner::generate(
   stats_.inference_start_ms = llm::time_in_ms();
   shouldStop_ = false;
 
-  // Set the sequence length to the max seq length if not provided
-  seq_len = (seq_len > 0 && seq_len <= metadata_.at(kMaxContextLen))
-      ? seq_len
-      : metadata_.at(kMaxContextLen);
-
   ::tokenizers::Result<std::vector<uint64_t>> encode_res = tokenizer_->encode(
       prompt,
       /* bos */ 0,
@@ -225,21 +225,22 @@ Error Runner::generate(
   ET_CHECK_MSG(
       num_prompt_tokens < metadata_.at(kMaxContextLen),
       "num_prompt_tokens %d >= max_seq_len_ %" PRId64
-      ", Max seq length exceeded - please increase max seq len value in .../llama2/model.py",
+      ", Max seq length exceeded - please increase max seq len value in your export script",
       num_prompt_tokens,
       metadata_.at(kMaxContextLen));
-  ET_CHECK_MSG(
-      num_prompt_tokens < seq_len,
-      "num_prompt_tokens %d >= seq_len %d, Sequence length exceeded - please increase the seq_len value passed to generate()",
-      num_prompt_tokens,
-      seq_len);
+
+  // Determine max_new_tokens using the GenerationConfig's resolve method
+  int max_new_tokens = config.resolve_max_new_tokens(
+      metadata_.at(kMaxContextLen), num_prompt_tokens);
+
+  ET_LOG(Info, "Max new tokens resolved: %d", max_new_tokens);
 
   // Prefill first
   // Here feed all tokens to the model and get the next predicted token
   // after the prompt. After that we will enter generate loop.
 
   // print prompts
-  if (echo) {
+  if (config.echo) {
     wrapped_callback(prompt);
   }
   int64_t pos = 0;
@@ -253,32 +254,38 @@ Error Runner::generate(
   wrapped_callback(
       ET_UNWRAP_TOKENIZER(tokenizer_->decode(cur_token, cur_token)));
   RUNNER_ET_LOG(
-      warmup,
+      config.warming,
       "RSS after prompt prefill: %f MiB (0 if unsupported)",
       llm::get_rss_bytes() / 1024.0 / 1024.0);
 
   // start the main loop
   prompt_tokens.push_back(cur_token);
+
+  // Generate max_new_tokens - 1 because prefill already generated 1 token.
   int64_t num_generated_tokens = ET_UNWRAP(text_token_generator_->generate(
-      prompt_tokens, num_prompt_tokens, seq_len, wrapped_callback));
+      prompt_tokens,
+      num_prompt_tokens,
+      max_new_tokens - 1,
+      temperature_ == -1.0f ? config.temperature : temperature_,
+      wrapped_callback));
 
   stats_.inference_end_ms = llm::time_in_ms();
-  if (!warmup) {
+  if (!config.warming) {
     printf("\n");
   }
   RUNNER_ET_LOG(
-      warmup,
+      config.warming,
       "RSS after finishing text generation: %f MiB (0 if unsupported)",
       llm::get_rss_bytes() / 1024.0 / 1024.0);
 
-  if (num_prompt_tokens + num_generated_tokens == seq_len) {
-    RUNNER_ET_LOG(warmup, "Sequence length (%i tokens) reached!", seq_len);
+  if (num_generated_tokens == max_new_tokens) {
+    RUNNER_ET_LOG(config.warming, "Max new tokens %i reached!", max_new_tokens);
   }
 
   stats_.num_prompt_tokens = num_prompt_tokens;
   stats_.num_generated_tokens = num_generated_tokens;
 
-  if (warmup) {
+  if (config.warming) {
     ET_LOG(Info, "Warmup run finished!");
   } else {
     // Do not print report during warmup
@@ -291,14 +298,15 @@ Error Runner::generate(
   return Error::Ok;
 }
 
-Error Runner::warmup(const std::string& prompt, int32_t seq_len) {
-  Error err = generate(
-      prompt,
-      seq_len,
-      /*token_callback=*/nullptr,
-      /*stats_callbak=*/nullptr,
-      /*echo=*/false,
-      /*warmup=*/true);
+Error Runner::warmup(const std::string& prompt, int32_t max_new_tokens) {
+  // Create a GenerationConfig for warmup
+  llm::GenerationConfig config{
+      .echo = false, .max_new_tokens = max_new_tokens, .warming = true};
+
+  // Call generate with the warmup config
+  Error err = generate(prompt, config);
+
+  // Reset stats after warmup
+
   stats_.reset();
   return err;
 }
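For reference, a minimal caller-side sketch of the GenerationConfig-based API introduced above. This is an illustration, not part of the diff: the example namespace, header path, and file paths are assumptions, while the two-argument constructor, the generate(prompt, config, token_callback) call, and the config fields (echo, max_new_tokens, temperature, warming) are taken from the changes shown here.

#include <executorch/examples/models/llama/runner/runner.h>  // assumed header path

#include <cstdio>
#include <string>

int main() {
  // Two-argument construction per the new constructor (optional data_path omitted);
  // the model and tokenizer paths are placeholders.
  example::Runner runner("llama.pte", "tokenizer.model");

  // Set only the fields referenced in the diff; everything else keeps its default.
  ::executorch::extension::llm::GenerationConfig config;
  config.echo = true;
  config.max_new_tokens = 128;
  config.temperature = 0.8f;

  // Decoded pieces are streamed through the token callback as they are produced;
  // the stats callback parameter is left at its default here.
  runner.generate(
      "Tell me a short story.",
      config,
      [](const std::string& piece) { printf("%s", piece.c_str()); });
  return 0;
}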