
Commit e89ba89

larryliu0820 authored and facebook-github-bot committed
Introduce GenerationConfig (#10228)
Summary: Starts to implement #9341 and starts to fix #8495. This PR introduces `GenerationConfig`, which holds the options that may change across invocations of `generate()`. For example, `temperature` is moved out of the runner constructor: it is not tied to the runner instance and should be adjustable on every call to `generate()`. Similarly, `echo` and `warming` move into the config. Both `seq_len` and `max_new_tokens` can be passed through the config, and the effective `max_new_tokens` is resolved from these two values, the PTE file metadata, and the number of prompt tokens.

Reviewed By: iseeyuan

Differential Revision: D73091676
1 parent cb80092 commit e89ba89
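
For reference, here is a minimal sketch of the new config type, reconstructed only from the call sites in this diff. The field defaults, declaration order, and exact header layout are assumptions, not a copy of the actual `GenerationConfig` header:

#include <cstdint>

namespace executorch::extension::llm {

// Sketch: these are the fields exercised by this diff.
struct GenerationConfig {
  // Echo the prompt back through the token callback before generating.
  bool echo = true;
  // Hard cap on newly generated tokens; -1 means "not set" (assumed default).
  int32_t max_new_tokens = -1;
  // Marks a warmup run: suppresses printing and the perf report.
  bool warming = false;
  // Total sequence budget (prompt + generated); -1 means "not set" (assumed).
  int32_t seq_len = -1;
  // Sampling temperature, previously a Runner constructor argument.
  float temperature = 0.8f;

  // Resolves the effective number of new tokens from the two knobs above,
  // the model's max context length (PTE metadata), and the prompt length.
  int32_t resolve_max_new_tokens(
      int32_t max_context_len, int32_t num_prompt_tokens) const;
};

} // namespace executorch::extension::llm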

File tree

21 files changed: +392 -126 lines changed


CMakeLists.txt

+6-2
@@ -761,12 +761,16 @@ if(EXECUTORCH_BUILD_EXTENSION_FLAT_TENSOR)
   add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/flat_tensor)
 endif()
 
+if(EXECUTORCH_BUILD_EXTENSION_MODULE)
+  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/module)
+endif()
+
 if(EXECUTORCH_BUILD_EXTENSION_LLM)
   add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/llm/tokenizers)
 endif()
 
-if(EXECUTORCH_BUILD_EXTENSION_MODULE)
-  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/module)
+if(EXECUTORCH_BUILD_EXTENSION_LLM_RUNNER)
+  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/llm/runner)
 endif()
 
 if(EXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL)

examples/demo-apps/apple_ios/LLaMA/LLaMARunner/LLaMARunner/Exported/LLaMARunner.mm

+5-1
@@ -12,6 +12,7 @@
 #import <executorch/examples/models/llama/runner/runner.h>
 #import <executorch/examples/models/llava/runner/llava_runner.h>
 
+using executorch::extension::llm::GenerationConfig;
 using executorch::extension::llm::Image;
 using executorch::runtime::Error;
 
@@ -61,8 +62,11 @@ - (BOOL)generate:(NSString*)prompt
                sequenceLength:(NSInteger)seq_len
             withTokenCallback:(nullable void (^)(NSString*))callback
                         error:(NSError**)error {
+  const GenerationConfig config{
+      .seq_len = static_cast<int32_t>(seq_len)
+  };
   const auto status = _runner->generate(
-      prompt.UTF8String, seq_len, [callback](const std::string& token) {
+      prompt.UTF8String, config, [callback](const std::string& token) {
         callback(@(token.c_str()));
       });
   if (status != Error::Ok) {

examples/mediatek/executor_runner/mtk_llama_runner.cpp

+2-4
@@ -80,11 +80,9 @@ bool MTKLlamaRunner::is_loaded() const {
 
 Error MTKLlamaRunner::generate(
     const std::string& prompt,
-    int32_t seq_len,
+    executorch::extension::llm::GenerationConfig config,
     std::function<void(const std::string&)> token_callback,
-    std::function<void(const Stats&)> stats_callback,
-    bool echo,
-    bool warming) {
+    std::function<void(const Stats&)> stats_callback) {
   if (!is_loaded()) {
     ET_CHECK_OK_OR_RETURN_ERROR(load());
   }

examples/mediatek/executor_runner/mtk_llama_runner.h

+2-4
@@ -43,11 +43,9 @@ class MTKLlamaRunner : public executorch::extension::llm::IRunner {
   Error load();
   Error generate(
       const std::string& prompt,
-      int32_t seq_len = 128,
+      executorch::extension::llm::GenerationConfig config,
       std::function<void(const std::string&)> token_callback = {},
-      std::function<void(const Stats&)> stats_callback = {},
-      bool echo = true,
-      bool warming = false);
+      std::function<void(const Stats&)> stats_callback = {});
   void stop();
 
   LlamaModelOptions get_model_options();

examples/models/llama/main.cpp

+9-4
@@ -53,7 +53,7 @@ int32_t main(int32_t argc, char** argv) {
 
   const char* prompt = FLAGS_prompt.c_str();
 
-  double temperature = FLAGS_temperature;
+  float temperature = FLAGS_temperature;
 
   int32_t seq_len = FLAGS_seq_len;
 
@@ -73,13 +73,18 @@ int32_t main(int32_t argc, char** argv) {
   }
 #endif
   // create llama runner
-  example::Runner runner(model_path, tokenizer_path, temperature);
+  // @lint-ignore CLANGTIDY facebook-hte-Deprecated
+  example::Runner runner(model_path, tokenizer_path);
 
   if (warmup) {
-    runner.warmup(prompt, seq_len);
+    // @lint-ignore CLANGTIDY facebook-hte-Deprecated
+    runner.warmup(prompt, /*max_new_tokens=*/seq_len);
   }
   // generate
-  runner.generate(prompt, seq_len);
+  executorch::extension::llm::GenerationConfig config{
+      .seq_len = seq_len, .temperature = temperature};
+  // @lint-ignore CLANGTIDY facebook-hte-Deprecated
+  runner.generate(prompt, config);
 
   return 0;
 }

examples/models/llama/runner/runner.cpp

+39-42
@@ -41,13 +41,11 @@ static constexpr auto kUseSDPAWithKVCache = "use_sdpa_with_kv_cache";
 Runner::Runner(
     const std::string& model_path,
     const std::string& tokenizer_path,
-    const float temperature,
     std::optional<const std::string> data_path)
     // NOTE: we observed ~2x loading performance increase on iPhone 15
     // and a ~5% improvement on Galaxy S22 by switching to
     // FileDataLoader instead of MmapDataLoader + UseMlockIgnoreErrors.
-    : temperature_(temperature),
-      tokenizer_path_(tokenizer_path),
+    : tokenizer_path_(tokenizer_path),
       metadata_({
           {kEnableDynamicShape, false},
           {kMaxSeqLen, 128},
@@ -133,11 +131,9 @@ Error Runner::load() {
       ET_LOG(Info, "eos_id = %" PRId64, value);
     }
   }
+  // @lint-ignore CLANGTIDY facebook-hte-Deprecated
   text_decoder_runner_ = std::make_unique<llm::TextDecoderRunner>(
-      module_.get(),
-      metadata_.at(kUseKVCache),
-      metadata_.at(kVocabSize),
-      temperature_);
+      module_.get(), metadata_.at(kUseKVCache));
   text_prefiller_ = std::make_unique<llm::TextPrefiller>(
       text_decoder_runner_.get(),
       metadata_.at(kUseKVCache),
@@ -164,11 +160,9 @@ Error Runner::load() {
 
 Error Runner::generate(
     const std::string& prompt,
-    int32_t seq_len,
+    const ::executorch::extension::llm::GenerationConfig& config,
     std::function<void(const std::string&)> token_callback,
-    std::function<void(const llm::Stats&)> stats_callback,
-    bool echo,
-    bool warmup) {
+    std::function<void(const llm::Stats&)> stats_callback) {
   // Prepare the inputs.
   // Use ones-initialized inputs.
   ET_CHECK_MSG(!prompt.empty(), "Prompt cannot be null");
@@ -178,19 +172,19 @@ Error Runner::generate(
     stats_.model_load_end_ms = llm::time_in_ms();
   }
 
-  if (warmup) {
+  if (config.warming) {
     ET_LOG(Info, "Doing a warmup run...");
   }
 
   RUNNER_ET_LOG(
-      warmup,
+      config.warming,
       "RSS after loading model: %f MiB (0 if unsupported)",
       llm::get_rss_bytes() / 1024.0 / 1024.0);
 
   // Wrap the token_callback with print function
   std::function<void(const std::string&)> wrapped_callback =
-      [token_callback, warmup](const std::string& piece) {
-        if (!warmup) {
+      [token_callback, config](const std::string& piece) {
+        if (!config.warming) {
           llm::safe_printf(piece.c_str());
           fflush(stdout);
         }
@@ -204,11 +198,6 @@ Error Runner::generate(
   stats_.inference_start_ms = llm::time_in_ms();
   shouldStop_ = false;
 
-  // Set the sequence length to the max seq length if not provided
-  seq_len = (seq_len > 0 && seq_len <= metadata_.at(kMaxContextLen))
-      ? seq_len
-      : metadata_.at(kMaxContextLen);
-
   ::tokenizers::Result<std::vector<uint64_t>> encode_res = tokenizer_->encode(
       prompt,
       /* bos */ 0,
@@ -225,21 +214,22 @@ Error Runner::generate(
   ET_CHECK_MSG(
       num_prompt_tokens < metadata_.at(kMaxContextLen),
       "num_prompt_tokens %d >= max_seq_len_ %" PRId64
-      ", Max seq length exceeded - please increase max seq len value in .../llama2/model.py",
+      ", Max seq length exceeded - please increase max seq len value in your export script",
       num_prompt_tokens,
       metadata_.at(kMaxContextLen));
-  ET_CHECK_MSG(
-      num_prompt_tokens < seq_len,
-      "num_prompt_tokens %d >= seq_len %d, Sequence length exceeded - please increase the seq_len value passed to generate()",
-      num_prompt_tokens,
-      seq_len);
+
+  // Determine max_new_tokens using the GenerationConfig's resolve method
+  int max_new_tokens = config.resolve_max_new_tokens(
+      metadata_.at(kMaxContextLen), num_prompt_tokens);
+
+  ET_LOG(Info, "Max new tokens resolved: %d", max_new_tokens);
 
   // Prefill first
   // Here feed all tokens to the model and get the next predicted token
   // after the prompt. After that we will enter generate loop.
 
   // print prompts
-  if (echo) {
+  if (config.echo) {
     wrapped_callback(prompt);
   }
   int64_t pos = 0;
@@ -253,32 +243,38 @@ Error Runner::generate(
   wrapped_callback(
       ET_UNWRAP_TOKENIZER(tokenizer_->decode(cur_token, cur_token)));
   RUNNER_ET_LOG(
-      warmup,
+      config.warming,
       "RSS after prompt prefill: %f MiB (0 if unsupported)",
       llm::get_rss_bytes() / 1024.0 / 1024.0);
 
   // start the main loop
   prompt_tokens.push_back(cur_token);
+
+  // Generate max_new_tokens - 1 because prefill already generated 1 token.
   int64_t num_generated_tokens = ET_UNWRAP(text_token_generator_->generate(
-      prompt_tokens, num_prompt_tokens, seq_len, wrapped_callback));
+      prompt_tokens,
+      num_prompt_tokens,
+      max_new_tokens - 1,
+      config.temperature,
+      wrapped_callback));
 
   stats_.inference_end_ms = llm::time_in_ms();
-  if (!warmup) {
+  if (!config.warming) {
     printf("\n");
   }
   RUNNER_ET_LOG(
-      warmup,
+      config.warming,
       "RSS after finishing text generation: %f MiB (0 if unsupported)",
       llm::get_rss_bytes() / 1024.0 / 1024.0);
 
-  if (num_prompt_tokens + num_generated_tokens == seq_len) {
-    RUNNER_ET_LOG(warmup, "Sequence length (%i tokens) reached!", seq_len);
+  if (num_generated_tokens == max_new_tokens) {
+    RUNNER_ET_LOG(config.warming, "Max new tokens %i reached!", max_new_tokens);
   }
 
   stats_.num_prompt_tokens = num_prompt_tokens;
   stats_.num_generated_tokens = num_generated_tokens;
 
-  if (warmup) {
+  if (config.warming) {
     ET_LOG(Info, "Warmup run finished!");
   } else {
     // Do not print report during warmup
@@ -291,14 +287,15 @@ Error Runner::generate(
   return Error::Ok;
 }
 
-Error Runner::warmup(const std::string& prompt, int32_t seq_len) {
-  Error err = generate(
-      prompt,
-      seq_len,
-      /*token_callback=*/nullptr,
-      /*stats_callbak=*/nullptr,
-      /*echo=*/false,
-      /*warmup=*/true);
+Error Runner::warmup(const std::string& prompt, int32_t max_new_tokens) {
+  // Create a GenerationConfig for warmup
+  llm::GenerationConfig config{
+      .echo = false, .max_new_tokens = max_new_tokens, .warming = true};
+
+  // Call generate with the warmup config
+  Error err = generate(prompt, config);
+
+  // Reset stats after warmup
   stats_.reset();
   return err;
 }
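
The resolution step above replaces both the old seq_len clamp and the num_prompt_tokens < seq_len check. Below is a plausible sketch of how resolve_max_new_tokens could combine its inputs, consistent with how this diff uses the result; it is an assumption, not the actual implementation:

#include <algorithm>
#include <cstdint>

// Hypothetical helper mirroring GenerationConfig::resolve_max_new_tokens;
// edge-case handling in the real method may differ.
int32_t resolve_max_new_tokens_sketch(
    int32_t max_context_len,    // PTE metadata (kMaxContextLen)
    int32_t num_prompt_tokens,  // tokens consumed by the prompt
    int32_t seq_len,            // config.seq_len, -1 if unset
    int32_t max_new_tokens) {   // config.max_new_tokens, -1 if unset
  // Clamp the requested total length to the model's context window.
  const int32_t effective_seq_len =
      (seq_len > 0 && seq_len <= max_context_len) ? seq_len : max_context_len;
  // Budget left for new tokens after the prompt.
  const int32_t budget = std::max(effective_seq_len - num_prompt_tokens, 0);
  // An explicit max_new_tokens further caps that budget.
  return max_new_tokens > 0 ? std::min(max_new_tokens, budget) : budget;
}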

examples/models/llama/runner/runner.h

+3-7
@@ -33,26 +33,22 @@ class ET_EXPERIMENTAL Runner : public executorch::extension::llm::IRunner {
   explicit Runner(
       const std::string& model_path,
       const std::string& tokenizer_path,
-      const float temperature = 0.8f,
       std::optional<const std::string> data_path = std::nullopt);
 
   bool is_loaded() const;
   ::executorch::runtime::Error load();
   ::executorch::runtime::Error generate(
       const std::string& prompt,
-      int32_t seq_len = 128,
+      const ::executorch::extension::llm::GenerationConfig& config,
       std::function<void(const std::string&)> token_callback = {},
       std::function<void(const ::executorch::extension::llm::Stats&)>
-          stats_callback = {},
-      bool echo = true,
-      bool warming = false);
+          stats_callback = {}) override;
   ::executorch::runtime::Error warmup(
       const std::string& prompt,
-      int32_t seq_len = 128);
+      int32_t max_new_tokens);
   void stop();
 
  private:
-  float temperature_;
   bool shouldStop_{false};
 
   // model
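
Putting the new interface together, here is a caller-side usage sketch consistent with the signatures above; the paths and config values are placeholders, not taken from the diff:

#include <executorch/examples/models/llama/runner/runner.h>

#include <string>

int main() {
  // Placeholder paths; the real CLI reads these from flags (see main.cpp).
  example::Runner runner("/path/to/model.pte", "/path/to/tokenizer.bin");

  executorch::extension::llm::GenerationConfig config{
      .echo = true, .seq_len = 256, .temperature = 0.7f};

  runner.generate(
      "Once upon a time",
      config,
      [](const std::string& piece) { /* stream tokens to the caller */ },
      [](const executorch::extension::llm::Stats& stats) { /* report perf */ });
  return 0;
}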

examples/models/llava/runner/llava_runner.cpp

+9-3
@@ -47,8 +47,10 @@ Error LlavaRunner::load() {
   tokenizer_->load(tokenizer_path_);
 
   // Load the text decoder runner
-  text_decoder_runner_ = std::make_unique<LlavaTextDecoderRunner>(
-      module_.get(), tokenizer_->vocab_size(), temperature_);
+  // @lint-ignore CLANGTIDY facebook-hte-Deprecated
+  text_decoder_runner_ =
+      std::make_unique<LlavaTextDecoderRunner>(module_.get());
+  // @lint-ignore CLANGTIDY facebook-hte-Deprecated
   text_decoder_runner_->load();
 
   // Load the text prefiller
@@ -117,7 +119,11 @@ Error LlavaRunner::generate_from_pos(
 
   // Generate tokens
   int64_t num_generated_tokens = ET_UNWRAP(text_token_generator_->generate(
-      {prefill_next_token}, start_pos, seq_len, token_callback));
+      /*tokens=*/{prefill_next_token},
+      /*start_pos=*/start_pos,
+      /*max_new_tokens=*/seq_len - start_pos + 1,
+      /*temperature=*/temperature_,
+      /*token_callback=*/token_callback));
 
   // Bookkeeping
   stats_.num_generated_tokens = num_generated_tokens;
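
To illustrate the new argument mapping: with, say, seq_len = 768 and start_pos = 100 (illustrative numbers, not from the diff), the generator is now asked for at most 768 - 100 + 1 = 669 new tokens, which appears to preserve the total-sequence-length budget that the old seq_len argument expressed.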

examples/models/llava/runner/llava_text_decoder_runner.h

+2-5
@@ -17,11 +17,8 @@ namespace example {
 class ET_EXPERIMENTAL LlavaTextDecoderRunner
     : public executorch::extension::llm::TextDecoderRunner {
  public:
-  LlavaTextDecoderRunner(
-      executorch::extension::Module* module,
-      int32_t vocab_size,
-      float temperature)
-      : TextDecoderRunner(module, true, vocab_size, temperature){};
+  LlavaTextDecoderRunner(executorch::extension::Module* module)
+      : TextDecoderRunner(module, true) {}
 
   inline executorch::runtime::Result<executorch::aten::Tensor> step(
       executorch::extension::TensorPtr& tokens,

extension/android/jni/jni_layer_llama.cpp

+6-3
@@ -219,12 +219,15 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
         [callback](const llm::Stats& result) { callback->onStats(result); },
         echo);
   } else if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) {
+    executorch::extension::llm::GenerationConfig config{
+        .echo = static_cast<bool>(echo),
+        .seq_len = seq_len,
+    };
     runner_->generate(
         prompt->toStdString(),
-        seq_len,
+        config,
         [callback](std::string result) { callback->onResult(result); },
-        [callback](const llm::Stats& result) { callback->onStats(result); },
-        echo);
+        [callback](const llm::Stats& result) { callback->onStats(result); });
   }
   return 0;
 }

extension/benchmark/apple/Benchmark/Tests/LLaMA/LLaMATests.mm

+8-4
@@ -85,14 +85,18 @@ @implementation LLaMATests
   [testCase measureWithMetrics:@[ tokensPerSecondMetric, [XCTClockMetric new], [XCTMemoryMetric new] ]
                          block:^{
     tokensPerSecondMetric.tokenCount = 0;
+    // Create a GenerationConfig object
+    ::executorch::extension::llm::GenerationConfig config{
+        .max_new_tokens = 50,
+        .warming = false,
+    };
+
     const auto status = runner->generate(
         "Once upon a time",
-        50,
+        config,
         [=](const std::string &token) {
           tokensPerSecondMetric.tokenCount++;
-        },
-        nullptr,
-        false);
+        });
     XCTAssertEqual(status, Error::Ok);
   }];
 },

extension/llm/runner/CMakeLists.txt

+4
@@ -53,3 +53,7 @@ target_include_directories(
   extension_llm_runner INTERFACE ${_common_include_directories}
   ${EXECUTORCH_ROOT}/extension/llm/tokenizers/include
 )
+
+if(BUILD_TESTING)
+  add_subdirectory(test)
+endif()
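
With this guard, the runner tests are only added when the build has testing enabled, for example by configuring CMake with -DBUILD_TESTING=ON (assuming the top-level project forwards the standard BUILD_TESTING option to this subdirectory).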
