
Commit 5ecf7b7

larryliu0820 authored and facebook-github-bot committed
Introduce GenerationConfig (#10228)
Summary: Starts to implement #9341 and to fix #8495. This PR introduces `GenerationConfig`, which holds the options that can change across invocations of `generate()`. For example, `temperature` is moved out of the runner constructor: it is not tied to the runner instance and should be adjustable on every call to `generate()`. Likewise, `echo` and `warming` move into the config. Both `seq_len` and `max_new_tokens` can be passed through the config, and the effective `max_new_tokens` is determined from these two values, the .pte file metadata, and the number of prompt tokens.

Reviewed By: iseeyuan

Differential Revision: D73091676
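For reference, the shape of the new config implied by the call sites in this diff looks roughly like the sketch below. The field order and the body of `resolve_max_new_tokens` are assumptions inferred from the designated initializers and the description above, not the actual header.

// Hypothetical sketch of GenerationConfig, inferred from the call sites in
// this diff; the real header under extension/llm/runner may differ.
#include <algorithm>
#include <cstdint>

namespace executorch::extension::llm {

struct GenerationConfig {
  // Field order follows the designated initializers used below
  // (e.g. {.seq_len = ..., .temperature = ...}).
  int32_t seq_len = -1;        // total budget, prompt + generated; -1 means unset
  int32_t max_new_tokens = -1; // cap on generated tokens; -1 means unset
  bool echo = true;            // echo the prompt through the token callback
  bool warming = false;        // warmup run: suppress printing and the stats report
  float temperature = 0.8f;    // sampling temperature, now a per-call knob

  // Resolve the effective number of new tokens from the two config knobs,
  // the max context length stored in the .pte metadata, and the prompt size.
  int32_t resolve_max_new_tokens(
      int64_t max_context_len, int32_t num_prompt_tokens) const {
    // Start from the room left in the context window after the prompt.
    int64_t budget = max_context_len - num_prompt_tokens;
    if (seq_len > 0) {
      // seq_len counts the prompt, so convert it to a new-token budget first.
      budget = std::min<int64_t>(budget, seq_len - num_prompt_tokens);
    }
    if (max_new_tokens > 0) {
      budget = std::min<int64_t>(budget, max_new_tokens);
    }
    return static_cast<int32_t>(std::max<int64_t>(budget, 0));
  }
};

} // namespace executorch::extension::llm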
1 parent 9154002 commit 5ecf7b7

File tree: 18 files changed, +360 -114 lines

examples/demo-apps/apple_ios/LLaMA/LLaMARunner/LLaMARunner/Exported/LLaMARunner.mm
+5 -1

@@ -12,6 +12,7 @@
 #import <executorch/examples/models/llama/runner/runner.h>
 #import <executorch/examples/models/llava/runner/llava_runner.h>
 
+using executorch::extension::llm::GenerationConfig;
 using executorch::extension::llm::Image;
 using executorch::runtime::Error;
 
@@ -61,8 +62,11 @@ - (BOOL)generate:(NSString*)prompt
                sequenceLength:(NSInteger)seq_len
             withTokenCallback:(nullable void (^)(NSString*))callback
                         error:(NSError**)error {
+  const GenerationConfig config{
+      .seq_len = static_cast<int32_t>(seq_len)
+  };
   const auto status = _runner->generate(
-      prompt.UTF8String, seq_len, [callback](const std::string& token) {
+      prompt.UTF8String, config, [callback](const std::string& token) {
         callback(@(token.c_str()));
       });
   if (status != Error::Ok) {

examples/mediatek/executor_runner/mtk_llama_runner.cpp
+2 -4

@@ -80,11 +80,9 @@ bool MTKLlamaRunner::is_loaded() const {
 
 Error MTKLlamaRunner::generate(
     const std::string& prompt,
-    int32_t seq_len,
+    executorch::extension::llm::GenerationConfig config,
     std::function<void(const std::string&)> token_callback,
-    std::function<void(const Stats&)> stats_callback,
-    bool echo,
-    bool warming) {
+    std::function<void(const Stats&)> stats_callback) {
   if (!is_loaded()) {
     ET_CHECK_OK_OR_RETURN_ERROR(load());
   }

examples/mediatek/executor_runner/mtk_llama_runner.h
+2 -4

@@ -43,11 +43,9 @@ class MTKLlamaRunner : public executorch::extension::llm::IRunner {
   Error load();
   Error generate(
       const std::string& prompt,
-      int32_t seq_len = 128,
+      executorch::extension::llm::GenerationConfig config,
       std::function<void(const std::string&)> token_callback = {},
-      std::function<void(const Stats&)> stats_callback = {},
-      bool echo = true,
-      bool warming = false);
+      std::function<void(const Stats&)> stats_callback = {});
   void stop();
 
   LlamaModelOptions get_model_options();

examples/models/llama/main.cpp
+6 -4

@@ -53,7 +53,7 @@ int32_t main(int32_t argc, char** argv) {
 
   const char* prompt = FLAGS_prompt.c_str();
 
-  double temperature = FLAGS_temperature;
+  float temperature = FLAGS_temperature;
 
   int32_t seq_len = FLAGS_seq_len;
 
@@ -73,13 +73,15 @@ int32_t main(int32_t argc, char** argv) {
   }
 #endif
   // create llama runner
-  example::Runner runner(model_path, tokenizer_path, temperature);
+  example::Runner runner(model_path, tokenizer_path);
 
   if (warmup) {
-    runner.warmup(prompt, seq_len);
+    runner.warmup(prompt, /*max_new_tokens=*/seq_len);
   }
   // generate
-  runner.generate(prompt, seq_len);
+  executorch::extension::llm::GenerationConfig config{
+      .seq_len = seq_len, .temperature = temperature};
+  runner.generate(prompt, config);
 
   return 0;
 }

examples/models/llama/runner/runner.cpp
+38 -42

@@ -41,13 +41,11 @@ static constexpr auto kUseSDPAWithKVCache = "use_sdpa_with_kv_cache";
 Runner::Runner(
     const std::string& model_path,
     const std::string& tokenizer_path,
-    const float temperature,
     std::optional<const std::string> data_path)
     // NOTE: we observed ~2x loading performance increase on iPhone 15
     // and a ~5% improvement on Galaxy S22 by switching to
     // FileDataLoader instead of MmapDataLoader + UseMlockIgnoreErrors.
-    : temperature_(temperature),
-      tokenizer_path_(tokenizer_path),
+    : tokenizer_path_(tokenizer_path),
       metadata_({
           {kEnableDynamicShape, false},
           {kMaxSeqLen, 128},
@@ -134,10 +132,7 @@ Error Runner::load() {
     }
   }
   text_decoder_runner_ = std::make_unique<llm::TextDecoderRunner>(
-      module_.get(),
-      metadata_.at(kUseKVCache),
-      metadata_.at(kVocabSize),
-      temperature_);
+      module_.get(), metadata_.at(kUseKVCache), metadata_.at(kVocabSize));
   text_prefiller_ = std::make_unique<llm::TextPrefiller>(
       text_decoder_runner_.get(),
       metadata_.at(kUseKVCache),
@@ -164,11 +159,9 @@ Error Runner::load() {
 
 Error Runner::generate(
     const std::string& prompt,
-    int32_t seq_len,
+    const ::executorch::extension::llm::GenerationConfig& config,
     std::function<void(const std::string&)> token_callback,
-    std::function<void(const llm::Stats&)> stats_callback,
-    bool echo,
-    bool warmup) {
+    std::function<void(const llm::Stats&)> stats_callback) {
   // Prepare the inputs.
   // Use ones-initialized inputs.
   ET_CHECK_MSG(!prompt.empty(), "Prompt cannot be null");
@@ -178,19 +171,19 @@ Error Runner::generate(
     stats_.model_load_end_ms = llm::time_in_ms();
   }
 
-  if (warmup) {
+  if (config.warming) {
     ET_LOG(Info, "Doing a warmup run...");
   }
 
   RUNNER_ET_LOG(
-      warmup,
+      config.warming,
       "RSS after loading model: %f MiB (0 if unsupported)",
       llm::get_rss_bytes() / 1024.0 / 1024.0);
 
   // Wrap the token_callback with print function
   std::function<void(const std::string&)> wrapped_callback =
-      [token_callback, warmup](const std::string& piece) {
-        if (!warmup) {
+      [token_callback, config](const std::string& piece) {
+        if (!config.warming) {
           llm::safe_printf(piece.c_str());
           fflush(stdout);
         }
@@ -204,11 +197,6 @@ Error Runner::generate(
   stats_.inference_start_ms = llm::time_in_ms();
   shouldStop_ = false;
 
-  // Set the sequence length to the max seq length if not provided
-  seq_len = (seq_len > 0 && seq_len <= metadata_.at(kMaxContextLen))
-      ? seq_len
-      : metadata_.at(kMaxContextLen);
-
   ::tokenizers::Result<std::vector<uint64_t>> encode_res = tokenizer_->encode(
       prompt,
       /* bos */ 0,
@@ -225,21 +213,22 @@ Error Runner::generate(
   ET_CHECK_MSG(
       num_prompt_tokens < metadata_.at(kMaxContextLen),
       "num_prompt_tokens %d >= max_seq_len_ %" PRId64
-      ", Max seq length exceeded - please increase max seq len value in .../llama2/model.py",
+      ", Max seq length exceeded - please increase max seq len value in your export script",
       num_prompt_tokens,
       metadata_.at(kMaxContextLen));
-  ET_CHECK_MSG(
-      num_prompt_tokens < seq_len,
-      "num_prompt_tokens %d >= seq_len %d, Sequence length exceeded - please increase the seq_len value passed to generate()",
-      num_prompt_tokens,
-      seq_len);
+
+  // Determine max_new_tokens using the GenerationConfig's resolve method
+  int max_new_tokens = config.resolve_max_new_tokens(
+      metadata_.at(kMaxContextLen), num_prompt_tokens);
+
+  ET_LOG(Info, "Max new tokens resolved: %d", max_new_tokens);
 
   // Prefill first
   // Here feed all tokens to the model and get the next predicted token
   // after the prompt. After that we will enter generate loop.
 
   // print prompts
-  if (echo) {
+  if (config.echo) {
     wrapped_callback(prompt);
   }
   int64_t pos = 0;
@@ -253,32 +242,38 @@ Error Runner::generate(
   wrapped_callback(
       ET_UNWRAP_TOKENIZER(tokenizer_->decode(cur_token, cur_token)));
   RUNNER_ET_LOG(
-      warmup,
+      config.warming,
       "RSS after prompt prefill: %f MiB (0 if unsupported)",
       llm::get_rss_bytes() / 1024.0 / 1024.0);
 
   // start the main loop
   prompt_tokens.push_back(cur_token);
+
+  // Generate max_new_tokens - 1 because prefill already generated 1 token.
   int64_t num_generated_tokens = ET_UNWRAP(text_token_generator_->generate(
-      prompt_tokens, num_prompt_tokens, seq_len, wrapped_callback));
+      prompt_tokens,
+      num_prompt_tokens,
+      max_new_tokens - 1,
+      config.temperature,
+      wrapped_callback));
 
   stats_.inference_end_ms = llm::time_in_ms();
-  if (!warmup) {
+  if (!config.warming) {
     printf("\n");
   }
   RUNNER_ET_LOG(
-      warmup,
+      config.warming,
       "RSS after finishing text generation: %f MiB (0 if unsupported)",
       llm::get_rss_bytes() / 1024.0 / 1024.0);
 
-  if (num_prompt_tokens + num_generated_tokens == seq_len) {
-    RUNNER_ET_LOG(warmup, "Sequence length (%i tokens) reached!", seq_len);
+  if (num_generated_tokens == max_new_tokens) {
+    RUNNER_ET_LOG(config.warming, "Max new tokens %i reached!", max_new_tokens);
   }
 
   stats_.num_prompt_tokens = num_prompt_tokens;
   stats_.num_generated_tokens = num_generated_tokens;
 
-  if (warmup) {
+  if (config.warming) {
     ET_LOG(Info, "Warmup run finished!");
   } else {
     // Do not print report during warmup
@@ -291,14 +286,15 @@ Error Runner::generate(
   return Error::Ok;
 }
 
-Error Runner::warmup(const std::string& prompt, int32_t seq_len) {
-  Error err = generate(
-      prompt,
-      seq_len,
-      /*token_callback=*/nullptr,
-      /*stats_callbak=*/nullptr,
-      /*echo=*/false,
-      /*warmup=*/true);
+Error Runner::warmup(const std::string& prompt, int32_t max_new_tokens) {
+  // Create a GenerationConfig for warmup
+  llm::GenerationConfig config{
+      .max_new_tokens = max_new_tokens, .echo = false, .warming = true};
+
+  // Call generate with the warmup config
+  Error err = generate(prompt, config);
+
+  // Reset stats after warmup
   stats_.reset();
   return err;
 }
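Putting the runner change together, a caller now passes a config object instead of loose arguments. Below is a minimal caller-side sketch under the headers touched in this commit; the function name and the knob values are made up for illustration.

// Hypothetical usage sketch of the new generate() contract.
#include <executorch/examples/models/llama/runner/runner.h>

#include <cstdio>
#include <string>

using executorch::extension::llm::GenerationConfig;
using executorch::runtime::Error;

int run_once(example::Runner& runner, const std::string& prompt) {
  // Either knob may be set; generate() resolves the effective budget from
  // whichever is present plus the model's max context length.
  GenerationConfig config;
  config.seq_len = 512;        // total budget including the prompt
  config.max_new_tokens = 128; // hard cap on generated tokens
  config.temperature = 0.7f;   // per-call sampling temperature

  Error err = runner.generate(
      prompt,
      config,
      [](const std::string& piece) { std::fputs(piece.c_str(), stdout); });
  return err == Error::Ok ? 0 : 1;
}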

examples/models/llama/runner/runner.h
+3 -7

@@ -33,26 +33,22 @@ class ET_EXPERIMENTAL Runner : public executorch::extension::llm::IRunner {
   explicit Runner(
       const std::string& model_path,
       const std::string& tokenizer_path,
-      const float temperature = 0.8f,
       std::optional<const std::string> data_path = std::nullopt);
 
   bool is_loaded() const;
   ::executorch::runtime::Error load();
   ::executorch::runtime::Error generate(
       const std::string& prompt,
-      int32_t seq_len = 128,
+      const ::executorch::extension::llm::GenerationConfig& config,
       std::function<void(const std::string&)> token_callback = {},
       std::function<void(const ::executorch::extension::llm::Stats&)>
-          stats_callback = {},
-      bool echo = true,
-      bool warming = false);
+          stats_callback = {});
   ::executorch::runtime::Error warmup(
       const std::string& prompt,
-      int32_t seq_len = 128);
+      int32_t max_new_tokens);
   void stop();
 
  private:
-  float temperature_;
   bool shouldStop_{false};
 
   // model

examples/models/llava/runner/llava_runner.cpp
+6 -2

@@ -48,7 +48,7 @@ Error LlavaRunner::load() {
 
   // Load the text decoder runner
   text_decoder_runner_ = std::make_unique<LlavaTextDecoderRunner>(
-      module_.get(), tokenizer_->vocab_size(), temperature_);
+      module_.get(), tokenizer_->vocab_size());
   text_decoder_runner_->load();
 
   // Load the text prefiller
@@ -117,7 +117,11 @@ Error LlavaRunner::generate_from_pos(
 
   // Generate tokens
   int64_t num_generated_tokens = ET_UNWRAP(text_token_generator_->generate(
-      {prefill_next_token}, start_pos, seq_len, token_callback));
+      /*tokens=*/{prefill_next_token},
+      /*start_pos=*/start_pos,
+      /*max_new_tokens=*/seq_len - start_pos + 1,
+      /*temperature=*/temperature_,
+      /*token_callback=*/token_callback));
 
   // Bookkeeping
   stats_.num_generated_tokens = num_generated_tokens;

examples/models/llava/runner/llava_text_decoder_runner.h
+2 -3

@@ -19,9 +19,8 @@ class ET_EXPERIMENTAL LlavaTextDecoderRunner
  public:
   LlavaTextDecoderRunner(
       executorch::extension::Module* module,
-      int32_t vocab_size,
-      float temperature)
-      : TextDecoderRunner(module, true, vocab_size, temperature){};
+      int32_t vocab_size)
+      : TextDecoderRunner(module, true, vocab_size){};
 
   inline executorch::runtime::Result<executorch::aten::Tensor> step(
       executorch::extension::TensorPtr& tokens,

extension/android/jni/jni_layer_llama.cpp
+6 -3

@@ -219,12 +219,15 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
         [callback](const llm::Stats& result) { callback->onStats(result); },
         echo);
   } else if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) {
+    executorch::extension::llm::GenerationConfig config{
+        .seq_len = seq_len,
+        .echo = echo,
+    };
     runner_->generate(
         prompt->toStdString(),
-        seq_len,
+        config,
         [callback](std::string result) { callback->onResult(result); },
-        [callback](const llm::Stats& result) { callback->onStats(result); },
-        echo);
+        [callback](const llm::Stats& result) { callback->onStats(result); });
   }
   return 0;
 }

extension/llm/runner/CMakeLists.txt
+4

@@ -53,3 +53,7 @@ target_include_directories(
   extension_llm_runner INTERFACE ${_common_include_directories}
   ${EXECUTORCH_ROOT}/extension/llm/tokenizers/include
 )
+
+if(BUILD_TESTING)
+  add_subdirectory(test)
+endif()
