 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-
 import argparse
-
 from typing import Optional, Union
 
 import torch
-
-from datasets import load_dataset
 from executorch.examples.models.llama.export_llama_lib import (
+    _convert_args_to_config,
+    _prepare_for_llama_export,
+    build_args_parser as _build_args_parser,
     get_quantizer_and_quant_params,
 )
 
 from executorch.extension.llm.export.builder import LLMEdgeManager
 from lm_eval.evaluator import simple_evaluate
+from omegaconf import DictConfig, OmegaConf
 from pytorch_tokenizers import get_tokenizer
 from pytorch_tokenizers.llama2c import Llama2cTokenizer as SentencePieceTokenizer
 from pytorch_tokenizers.tiktoken import TiktokenTokenizer as Tiktoken
-from torch.nn import CrossEntropyLoss
-from tqdm import tqdm
 
 from .evaluate.eager_eval import EagerEvalWrapper
 
-from .export_llama_lib import (
-    _prepare_for_llama_export,
-    build_args_parser as _build_args_parser,
-)
-
 
 class GraphModuleEvalWrapper(EagerEvalWrapper):
     """
@@ -163,7 +156,7 @@ def _model_call(self, inps):
 
 def gen_eval_wrapper(
     model_name: str,
-    args: argparse.ArgumentParser,
+    config: DictConfig,
 ):
     """
     Generates a wrapper interface around the provided model and tokenizer for
@@ -172,17 +165,17 @@ def gen_eval_wrapper(
     Returns:
         eval_wrapper (LM): A wrapper interface for the lm-evaluation-harness library.
     """
-    tokenizer = get_tokenizer(args.tokenizer_path)  # pyre-ignore
+    tokenizer = get_tokenizer(config.model.tokenizer_path)
 
     # ExecuTorch Binary Evaluation
-    if (model := args.pte) is not None:  # pyre-ignore
-        if (tokenizer_bin := args.tokenizer_bin) is not None:  # pyre-ignore
+    if (model := config.eval.pte) is not None:
+        if (tokenizer_bin := config.eval.tokenizer_bin) is not None:
             # ETRunnerEvalWrapper: Create a wrapper around an ExecuTorch model, evaluated at runtime
             return ETRunnerEvalWrapper(
                 model=model,
                 tokenizer=tokenizer,
                 tokenizer_bin=tokenizer_bin,
-                max_seq_length=args.max_seq_length,  # pyre-ignore
+                max_seq_length=config.sequence.max_seq_length,
             )
 
         # ETPybindEvalWrapper: Create a wrapper around an ExecuTorch model, evaluated with pybindings
@@ -191,12 +184,12 @@ def gen_eval_wrapper(
             tokenizer=tokenizer,
             # Exported model takes at most (max_seq_length - 1) tokens.
             # Note that the eager model takes at most max_seq_length tokens.
-            max_seq_length=args.max_seq_length - 1,
+            max_seq_length=config.sequence.max_seq_length - 1,
         )
 
-    pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args)
+    pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(config)
     # GPTFastEvalWrapper: Create a wrapper around a pre-exported model
-    manager: LLMEdgeManager = _prepare_for_llama_export(args)
+    manager: LLMEdgeManager = _prepare_for_llama_export(config)
 
     if len(quantizers) != 0:
         manager = manager.export().pt2e_quantize(quantizers)
@@ -208,9 +201,9 @@ def gen_eval_wrapper(
         return GraphModuleEvalWrapper(
             model=model,
             tokenizer=tokenizer,
-            max_seq_length=args.max_seq_length,
-            use_kv_cache=args.use_kv_cache,  # pyre-ignore
-            enable_dynamic_shape=args.enable_dynamic_shape,  # pyre-ignore
+            max_seq_length=config.sequence.max_seq_length,
+            use_kv_cache=config.kv_cache.use_kv_cache,
+            enable_dynamic_shape=config.misc.enable_dynamic_shape,
         )
     else:
         # TODO: use manager.pre_autograd_graph_module for the eval to remove the if-else branch
@@ -228,18 +221,94 @@ def gen_eval_wrapper(
         # that is not available in this eval_llama. We save the checkpoint
         # here for consistency with eval_llama. The accuracy results we
         # get from eval_llama can be used as a reference to other evaluations.
-        if args.output_eager_checkpoint_file is not None:  # pyre-ignore
-            torch.save(model, args.output_eager_checkpoint_file)
+        if config.eval.output_eager_checkpoint_file is not None:
+            torch.save(model, config.eval.output_eager_checkpoint_file)
 
         return EagerEvalWrapper(
             model=model,
             tokenizer=tokenizer,
-            max_seq_length=args.max_seq_length,
-            use_kv_cache=args.use_kv_cache,
+            max_seq_length=config.sequence.max_seq_length,
+            use_kv_cache=config.kv_cache.use_kv_cache,
+        )
+
+
+def eval_llama(
+    model_name: str,
+    config: DictConfig,
+) -> None:
+    # Generate the eval wrapper
+    eval_wrapper = gen_eval_wrapper(model_name, config)
+
+    # Needed for loading mmlu dataset.
+    # See https://github.com/EleutherAI/lm-evaluation-harness/pull/1998/files
+    if config.eval.tasks and "mmlu" in config.eval.tasks:
+        import datasets
+
+        datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = True
+
+    # Evaluate the model
+    tasks = (
+        None if config.eval.tasks is None else OmegaConf.to_container(config.eval.tasks)
+    )
+    with torch.no_grad():
+        eval_results = simple_evaluate(
+            model=eval_wrapper,
+            tasks=tasks,
+            num_fewshot=config.eval.num_fewshot,
+            limit=config.eval.limit,
         )
 
+    for task, res in eval_results["results"].items():
+        print(f"{task}: {res}")
+
+
+def eval_llama_with_attention_sink(
+    model_name: str,
+    config: DictConfig,
+) -> None:
+    # Generate the eval wrapper
+    eval_wrapper = gen_eval_wrapper(model_name, config)
+
+    # Needed for loading mmlu dataset.
+    # See https://github.com/EleutherAI/lm-evaluation-harness/pull/1998/files
+    if config.eval.tasks and "mmlu" in config.eval.tasks:
+        import datasets
+
+        datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = True
+
+    # Evaluate the model
+    with torch.no_grad():
+        eval_results = simple_evaluate(
+            model=eval_wrapper,
+            tasks=OmegaConf.to_container(config.eval.tasks),
+            num_fewshot=config.eval.num_fewshot,
+            limit=config.eval.limit,
+        )
+
+    for task, res in eval_results["results"].items():
+        print(f"{task}: {res}")
+
+
+def _convert_cli_to_config_format(args) -> DictConfig:
+    """Convert CLI arguments to config format."""
+    # First convert common args using the shared function
+    config = _convert_args_to_config(args)
+
+    # Add evaluation-specific settings
+    config.eval = OmegaConf.create()
+    config.eval.tasks = args.tasks
+    config.eval.limit = args.limit
+    config.eval.num_fewshot = args.num_fewshot
+    config.eval.pte = args.pte
+    config.eval.tokenizer_bin = args.tokenizer_bin
+    config.eval.output_eager_checkpoint_file = args.output_eager_checkpoint_file
+    config.eval.attention_sink_eval_tokens = args.attention_sink_eval_tokens
+
+    return config
+
 
 def build_args_parser() -> argparse.ArgumentParser:
+    """Build argument parser for evaluation, extending the export parser with eval-specific args."""
     # Start with arg parser from export_llama_lib
     parser = _build_args_parser()
 
@@ -286,92 +355,7 @@ def build_args_parser() -> argparse.ArgumentParser:
         help="Save the checkpoint after source transformations, for other evaluation platform to run the same checkpoint.",
     )
 
-    # Set of parameters secpific to AttentionSink.
+    # Set of parameters specific to AttentionSink.
     parser.add_argument("--attention_sink_eval_tokens", type=int, default=0)
 
     return parser
-
-
-def eval_llama(
-    model_name: str,
-    args: argparse.ArgumentParser,
-) -> None:
-    # Generate the eval wrapper
-    eval_wrapper = gen_eval_wrapper(model_name, args)
-
-    # Needed for loading mmlu dataset.
-    # See https://github.com/EleutherAI/lm-evaluation-harness/pull/1998/files
-    # pyre-ignore: Undefined attribute [16]: `argparse.ArgumentParser` has no attribute `tasks`
-    if args.tasks and "mmlu" in args.tasks:
-        import datasets
-
-        datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = True
-
-    # Evaluate the model
-    with torch.no_grad():
-        eval_results = simple_evaluate(
-            model=eval_wrapper,
-            tasks=args.tasks,
-            num_fewshot=args.num_fewshot,  # pyre-ignore: Undefined attribute [16]: `argparse.ArgumentParser` has no attribute `num_fewshot`
-            limit=args.limit,  # pyre-ignore: Undefined attribute [16]: `argparse.ArgumentParser` has no attribute `limit`
-        )
-
-    for task, res in eval_results["results"].items():
-        print(f"{task}: {res}")
-
-
-def eval_llama_with_attention_sink(model_name: str, args: argparse.ArgumentParser):
-    """
-    Evaluate the model's perplexity when AttentionSink is enabled.
-
-    This is mostly copied from https://github.com/mit-han-lab/streaming-llm/blob/main/examples/eval_long_ppl.py
-    """
-    assert args.use_attention_sink is not None  # pyre-ignore [16]
-    assert args.attention_sink_eval_tokens > 0  # pyre-ignore [16]
-    attention_sink_params = args.use_attention_sink.split(",")
-    assert len(attention_sink_params) == 3
-    sink_size = int(attention_sink_params[0])
-    window_size = int(attention_sink_params[1])
-
-    assert args.max_seq_length == sink_size + window_size  # pyre-ignore [16]
-
-    device = "cuda" if torch.cuda.is_available() else "cpu"
-    manager: LLMEdgeManager = _prepare_for_llama_export(args)
-    model = manager.model.eval().to(device=device)
-    tokenizer = get_tokenizer(args.tokenizer_path)  # pyre-ignore [16]
-
-    eval_data = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
-
-    nlls = []
-    loss_fn = CrossEntropyLoss(reduction="none")
-    progress_bar = tqdm(total=args.attention_sink_eval_tokens)
-    input_pos = 0
-    while input_pos < args.attention_sink_eval_tokens:
-        for text in eval_data["text"]:  # pyre-ignore [16]
-            tokens = tokenizer.encode(text, bos=False, eos=False)
-            if len(tokens) <= 0:
-                continue
-            with torch.no_grad():
-                num_tokens = min(
-                    len(tokens) - 1, args.attention_sink_eval_tokens - input_pos
-                )
-                logits = model(
-                    torch.tensor(
-                        [tokens[:num_tokens]], dtype=torch.int64, device=device
-                    ),
-                    torch.tensor([input_pos], dtype=torch.int64, device=device),
-                ).squeeze(dim=0)
-                neg_log_likelihood = loss_fn(
-                    logits,
-                    torch.tensor(
-                        [tokens[1 : num_tokens + 1]], dtype=torch.int64, device=device
-                    ).view(-1),
-                )
-                nlls.append(neg_log_likelihood)
-                input_pos += num_tokens
-                progress_bar.update(num_tokens)
-            if input_pos >= args.attention_sink_eval_tokens:
-                break
-    ppl = torch.exp(torch.cat(nlls).mean())
-    print(f"Perplexity: {ppl.item()}")
-    return ppl.item()
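
For reviewers who want to exercise the new config-driven flow end to end, here is a minimal usage sketch of how the pieces added above fit together: parse CLI flags with the extended parser, convert them into a DictConfig with _convert_cli_to_config_format, and pass the config to eval_llama. The module import path, the example checkpoint/tokenizer paths, and the "llama2" model name are illustrative assumptions, not part of this diff.

# Minimal usage sketch; paths, flag values, and the import path below are
# illustrative assumptions and are not defined by this change.
from executorch.examples.models.llama.eval_llama_lib import (
    _convert_cli_to_config_format,
    build_args_parser,
    eval_llama,
)

parser = build_args_parser()
args = parser.parse_args(
    [
        "--checkpoint", "llama2.pt",
        "--params", "params.json",
        "--tokenizer_path", "tokenizer.model",
        "--tasks", "wikitext",
        "--limit", "5",
    ]
)

# CLI flags are folded into a DictConfig (config.model.*, config.sequence.*,
# config.eval.*, ...) and handed to the evaluation entry point.
config = _convert_cli_to_config_format(args)
eval_llama("llama2", config)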