@@ -4,35 +4,27 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-
 import argparse
-
 from typing import Optional, Union
 
 import torch
-
-from datasets import load_dataset
 from executorch.examples.models.llama.export_llama_lib import (
+    _convert_args_to_config,
+    _prepare_for_llama_export,
+    build_args_parser as _build_args_parser,
     get_quantizer_and_quant_params,
 )
 from executorch.examples.models.llama.tokenizer.tiktoken import Tokenizer as Tiktoken
-
 from executorch.extension.llm.export.builder import LLMEdgeManager
 from executorch.extension.llm.tokenizer.tokenizer import (
     Tokenizer as SentencePieceTokenizer,
 )
 from executorch.extension.llm.tokenizer.utils import get_tokenizer
 from lm_eval.evaluator import simple_evaluate
-from torch.nn import CrossEntropyLoss
-from tqdm import tqdm
+from omegaconf import DictConfig, OmegaConf
 
 from .evaluate.eager_eval import EagerEvalWrapper
 
-from .export_llama_lib import (
-    _prepare_for_llama_export,
-    build_args_parser as _build_args_parser,
-)
-
 
 class GraphModuleEvalWrapper(EagerEvalWrapper):
     """
@@ -165,7 +157,7 @@ def _model_call(self, inps):
 
 def gen_eval_wrapper(
     model_name: str,
-    args: argparse.ArgumentParser,
+    config: DictConfig,
 ):
     """
     Generates a wrapper interface around the provided model and tokenizer for
@@ -174,17 +166,17 @@ def gen_eval_wrapper(
     Returns:
         eval_wrapper (LM): A wrapper interface for the lm-evaluation-harness library.
     """
-    tokenizer = get_tokenizer(args.tokenizer_path)  # pyre-ignore
+    tokenizer = get_tokenizer(config.model.tokenizer_path)
 
     # ExecuTorch Binary Evaluation
-    if (model := args.pte) is not None:  # pyre-ignore
-        if (tokenizer_bin := args.tokenizer_bin) is not None:  # pyre-ignore
+    if (model := config.eval.pte) is not None:
+        if (tokenizer_bin := config.eval.tokenizer_bin) is not None:
             # ETRunnerEvalWrapper: Create a wrapper around an ExecuTorch model, evaluated at runtime
             return ETRunnerEvalWrapper(
                 model=model,
                 tokenizer=tokenizer,
                 tokenizer_bin=tokenizer_bin,
-                max_seq_length=args.max_seq_length,  # pyre-ignore
+                max_seq_length=config.sequence.max_seq_length,
             )
 
         # ETPybindEvalWrapper: Create a wrapper around an ExecuTorch model, evaluated with pybindings
@@ -193,12 +185,12 @@ def gen_eval_wrapper(
             tokenizer=tokenizer,
             # Exported model takes at most (max_seq_length - 1) tokens.
             # Note that the eager model takes at most max_seq_length tokens.
-            max_seq_length=args.max_seq_length - 1,
+            max_seq_length=config.sequence.max_seq_length - 1,
         )
 
-    pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args)
+    pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(config)
     # GPTFastEvalWrapper: Create a wrapper around a pre-exported model
-    manager: LLMEdgeManager = _prepare_for_llama_export(args)
+    manager: LLMEdgeManager = _prepare_for_llama_export(config)
 
     if len(quantizers) != 0:
         manager = manager.export().pt2e_quantize(quantizers)
@@ -210,9 +202,9 @@ def gen_eval_wrapper(
         return GraphModuleEvalWrapper(
             model=model,
             tokenizer=tokenizer,
-            max_seq_length=args.max_seq_length,
-            use_kv_cache=args.use_kv_cache,  # pyre-ignore
-            enable_dynamic_shape=args.enable_dynamic_shape,  # pyre-ignore
+            max_seq_length=config.sequence.max_seq_length,
+            use_kv_cache=config.kv_cache.use_kv_cache,  # pyre-ignore
+            enable_dynamic_shape=config.misc.enable_dynamic_shape,  # pyre-ignore
         )
     else:
         # TODO: use manager.pre_autograd_graph_module for the eval to remove the if-else branch
@@ -230,18 +222,94 @@ def gen_eval_wrapper(
         # that is not available in this eval_llama. We save the checkpoint
         # here for consistency with eval_llama. The accuracy results we
         # get from eval_llama can be used as a reference to other evaluations.
-        if args.output_eager_checkpoint_file is not None:  # pyre-ignore
-            torch.save(model, args.output_eager_checkpoint_file)
+        if config.eval.output_eager_checkpoint_file is not None:  # pyre-ignore
+            torch.save(model, config.eval.output_eager_checkpoint_file)
 
         return EagerEvalWrapper(
             model=model,
             tokenizer=tokenizer,
-            max_seq_length=args.max_seq_length,
-            use_kv_cache=args.use_kv_cache,
+            max_seq_length=config.sequence.max_seq_length,
+            use_kv_cache=config.kv_cache.use_kv_cache,
+        )
+
+
+def eval_llama(
+    model_name: str,
+    config: DictConfig,
+) -> None:
+    # Generate the eval wrapper
+    eval_wrapper = gen_eval_wrapper(model_name, config)
+
+    # Needed for loading mmlu dataset.
+    # See https://github.com/EleutherAI/lm-evaluation-harness/pull/1998/files
+    if config.eval.tasks and "mmlu" in config.eval.tasks:
+        import datasets
+
+        datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = True
+
+    # Evaluate the model
+    tasks = (
+        None if config.eval.tasks is None else OmegaConf.to_container(config.eval.tasks)
+    )
+    with torch.no_grad():
+        eval_results = simple_evaluate(
+            model=eval_wrapper,
+            tasks=tasks,
+            num_fewshot=config.eval.num_fewshot,
+            limit=config.eval.limit,
+        )
+
+    for task, res in eval_results["results"].items():
+        print(f"{task}: {res}")
+
+
+def eval_llama_with_attention_sink(
+    model_name: str,
+    config: DictConfig,
+) -> None:
+    # Generate the eval wrapper
+    eval_wrapper = gen_eval_wrapper(model_name, config)
+
+    # Needed for loading mmlu dataset.
+    # See https://github.com/EleutherAI/lm-evaluation-harness/pull/1998/files
+    if config.eval.tasks and "mmlu" in config.eval.tasks:
+        import datasets
+
+        datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = True
+
+    # Evaluate the model
+    with torch.no_grad():
+        eval_results = simple_evaluate(
+            model=eval_wrapper,
+            tasks=OmegaConf.to_container(config.eval.tasks),
+            num_fewshot=config.eval.num_fewshot,
+            limit=config.eval.limit,
         )
 
+    for task, res in eval_results["results"].items():
+        print(f"{task}: {res}")
+
+
+def _convert_cli_to_config_format(args) -> DictConfig:
+    """Convert CLI arguments to config format."""
+    # First convert common args using the shared function
+    config = _convert_args_to_config(args)
+
+    # Add evaluation-specific settings
+    config.eval = OmegaConf.create()
+    config.eval.tasks = args.tasks
+    config.eval.limit = args.limit
+    config.eval.num_fewshot = args.num_fewshot
+    config.eval.pte = args.pte
+    config.eval.tokenizer_bin = args.tokenizer_bin
+    config.eval.output_eager_checkpoint_file = args.output_eager_checkpoint_file
+    config.eval.attention_sink_eval_tokens = args.attention_sink_eval_tokens
+
+    return config
+
 
 def build_args_parser() -> argparse.ArgumentParser:
+    """Build argument parser for evaluation, extending the export parser with eval-specific args."""
     # Start with arg parser from export_llama_lib
     parser = _build_args_parser()
 
@@ -288,92 +356,7 @@ def build_args_parser() -> argparse.ArgumentParser:
         help="Save the checkpoint after source transformations, for other evaluation platform to run the same checkpoint.",
     )
 
-    # Set of parameters secpific to AttentionSink.
+    # Set of parameters specific to AttentionSink.
     parser.add_argument("--attention_sink_eval_tokens", type=int, default=0)
 
     return parser
-
-
-def eval_llama(
-    model_name: str,
-    args: argparse.ArgumentParser,
-) -> None:
-    # Generate the eval wrapper
-    eval_wrapper = gen_eval_wrapper(model_name, args)
-
-    # Needed for loading mmlu dataset.
-    # See https://github.com/EleutherAI/lm-evaluation-harness/pull/1998/files
-    # pyre-ignore: Undefined attribute [16]: `argparse.ArgumentParser` has no attribute `tasks`
-    if args.tasks and "mmlu" in args.tasks:
-        import datasets
-
-        datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = True
-
-    # Evaluate the model
-    with torch.no_grad():
-        eval_results = simple_evaluate(
-            model=eval_wrapper,
-            tasks=args.tasks,
-            num_fewshot=args.num_fewshot,  # pyre-ignore: Undefined attribute [16]: `argparse.ArgumentParser` has no attribute `num_fewshot`
-            limit=args.limit,  # pyre-ignore: Undefined attribute [16]: `argparse.ArgumentParser` has no attribute `limit`
-        )
-
-    for task, res in eval_results["results"].items():
-        print(f"{task}: {res}")
-
-
-def eval_llama_with_attention_sink(model_name: str, args: argparse.ArgumentParser):
-    """
-    Evaluate the model's perplexity when AttentionSink is enabled.
-
-    This is mostly copied from https://github.com/mit-han-lab/streaming-llm/blob/main/examples/eval_long_ppl.py
-    """
-    assert args.use_attention_sink is not None  # pyre-ignore [16]
-    assert args.attention_sink_eval_tokens > 0  # pyre-ignore [16]
-    attention_sink_params = args.use_attention_sink.split(",")
-    assert len(attention_sink_params) == 3
-    sink_size = int(attention_sink_params[0])
-    window_size = int(attention_sink_params[1])
-
-    assert args.max_seq_length == sink_size + window_size  # pyre-ignore [16]
-
-    device = "cuda" if torch.cuda.is_available() else "cpu"
-    manager: LLMEdgeManager = _prepare_for_llama_export(args)
-    model = manager.model.eval().to(device=device)
-    tokenizer = get_tokenizer(args.tokenizer_path)  # pyre-ignore [16]
-
-    eval_data = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
-
-    nlls = []
-    loss_fn = CrossEntropyLoss(reduction="none")
-    progress_bar = tqdm(total=args.attention_sink_eval_tokens)
-    input_pos = 0
-    while input_pos < args.attention_sink_eval_tokens:
-        for text in eval_data["text"]:  # pyre-ignore [16]
-            tokens = tokenizer.encode(text, bos=False, eos=False)
-            if len(tokens) <= 0:
-                continue
-            with torch.no_grad():
-                num_tokens = min(
-                    len(tokens) - 1, args.attention_sink_eval_tokens - input_pos
-                )
-                logits = model(
-                    torch.tensor(
-                        [tokens[:num_tokens]], dtype=torch.int64, device=device
-                    ),
-                    torch.tensor([input_pos], dtype=torch.int64, device=device),
-                ).squeeze(dim=0)
-                neg_log_likelihood = loss_fn(
-                    logits,
-                    torch.tensor(
-                        [tokens[1 : num_tokens + 1]], dtype=torch.int64, device=device
-                    ).view(-1),
-                )
-                nlls.append(neg_log_likelihood)
-                input_pos += num_tokens
-                progress_bar.update(num_tokens)
-                if input_pos >= args.attention_sink_eval_tokens:
-                    break
-    ppl = torch.exp(torch.cat(nlls).mean())
-    print(f"Perplexity: {ppl.item()}")
-    return ppl.item()
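
Editor's note on the config shape: the accesses in this diff (`config.model.tokenizer_path`, `config.sequence.max_seq_length`, `config.kv_cache.use_kv_cache`, `config.misc.enable_dynamic_shape`, and the `config.eval.*` fields set by `_convert_cli_to_config_format`) imply a nested OmegaConf structure with `model`, `sequence`, `kv_cache`, `misc`, and `eval` groups. The snippet below is a minimal hand-built sketch of that shape; the authoritative schema comes from `_convert_args_to_config` in `export_llama_lib`, which is not shown in this diff, so the exact keys and default values here are assumptions.

```python
from omegaconf import OmegaConf

# Minimal hand-built config mirroring only the fields this diff reads.
# The real config is produced by _convert_cli_to_config_format(args);
# keys and defaults below are assumptions for illustration.
config = OmegaConf.create(
    {
        "model": {"tokenizer_path": "/path/to/tokenizer.model"},
        "sequence": {"max_seq_length": 2048},
        "kv_cache": {"use_kv_cache": True},
        "misc": {"enable_dynamic_shape": False},
        "eval": {
            "tasks": ["wikitext"],
            "limit": None,
            "num_fewshot": None,
            "pte": None,  # a .pte path here selects runtime evaluation
            "tokenizer_bin": None,
            "output_eager_checkpoint_file": None,
            "attention_sink_eval_tokens": 0,
        },
    }
)

assert config.sequence.max_seq_length == 2048
```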
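And a minimal sketch of how a driver might wire the new pieces together, e.g. inside this module's own `__main__` guard. Only `build_args_parser`, `_convert_cli_to_config_format`, `eval_llama`, and `eval_llama_with_attention_sink` come from this diff; the `main` function, the `"llama2"` model name, and the attention-sink dispatch condition are illustrative assumptions.

```python
import sys

def main() -> None:
    # Parse CLI flags with the eval-specific parser defined above.
    parser = build_args_parser()
    args = parser.parse_args(sys.argv[1:])

    # Convert the argparse namespace into the nested DictConfig
    # that the eval entry points now expect.
    config = _convert_cli_to_config_format(args)

    # Illustrative dispatch: run the attention-sink evaluation only
    # when its token budget was requested (assumed condition).
    if config.eval.attention_sink_eval_tokens > 0:
        eval_llama_with_attention_sink("llama2", config)
    else:
        eval_llama("llama2", config)


if __name__ == "__main__":
    main()
```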