 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-
 import argparse
-
 from typing import Optional, Union
 
 import torch
-
-from datasets import load_dataset
 from executorch.examples.models.llama.export_llama_lib import (
+    _convert_args_to_config,
+    _prepare_for_llama_export,
+    build_args_parser as _build_args_parser,
     get_quantizer_and_quant_params,
 )
 
 from executorch.extension.llm.export.builder import LLMEdgeManager
 from lm_eval.evaluator import simple_evaluate
+from omegaconf import DictConfig, OmegaConf
 from pytorch_tokenizers import get_tokenizer
 from pytorch_tokenizers.llama2c import Llama2cTokenizer as SentencePieceTokenizer
 from pytorch_tokenizers.tiktoken import TiktokenTokenizer as Tiktoken
-from torch.nn import CrossEntropyLoss
-from tqdm import tqdm
 
 from .evaluate.eager_eval import EagerEvalWrapper
 
-from .export_llama_lib import (
-    _prepare_for_llama_export,
-    build_args_parser as _build_args_parser,
-)
-
 
 class GraphModuleEvalWrapper(EagerEvalWrapper):
     """
@@ -163,7 +156,7 @@ def _model_call(self, inps):
 
 def gen_eval_wrapper(
     model_name: str,
-    args: argparse.ArgumentParser,
+    config: DictConfig,
 ):
     """
     Generates a wrapper interface around the provided model and tokenizer for
@@ -172,17 +165,17 @@ def gen_eval_wrapper(
     Returns:
         eval_wrapper (LM): A wrapper interface for the lm-evaluation-harness library.
     """
-    tokenizer = get_tokenizer(args.tokenizer_path)  # pyre-ignore
+    tokenizer = get_tokenizer(config.model.tokenizer_path)
 
     # ExecuTorch Binary Evaluation
-    if (model := args.pte) is not None:  # pyre-ignore
-        if (tokenizer_bin := args.tokenizer_bin) is not None:  # pyre-ignore
+    if (model := config.eval.pte) is not None:
+        if (tokenizer_bin := config.eval.tokenizer_bin) is not None:
             # ETRunnerEvalWrapper: Create a wrapper around an ExecuTorch model, evaluated at runtime
             return ETRunnerEvalWrapper(
                 model=model,
                 tokenizer=tokenizer,
                 tokenizer_bin=tokenizer_bin,
-                max_seq_length=args.max_seq_length,  # pyre-ignore
+                max_seq_length=config.sequence.max_seq_length,
             )
 
         # ETPybindEvalWrapper: Create a wrapper around an ExecuTorch model, evaluated with pybindings
@@ -191,12 +184,12 @@ def gen_eval_wrapper(
             tokenizer=tokenizer,
             # Exported model takes at most (max_seq_length - 1) tokens.
             # Note that the eager model takes at most max_seq_length tokens.
-            max_seq_length=args.max_seq_length - 1,
+            max_seq_length=config.sequence.max_seq_length - 1,
         )
 
-    pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args)
+    pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(config)
     # GPTFastEvalWrapper: Create a wrapper around a pre-exported model
-    manager: LLMEdgeManager = _prepare_for_llama_export(args)
+    manager: LLMEdgeManager = _prepare_for_llama_export(config)
 
     if len(quantizers) != 0:
         manager = manager.export().pt2e_quantize(quantizers)
@@ -208,9 +201,9 @@ def gen_eval_wrapper(
         return GraphModuleEvalWrapper(
             model=model,
             tokenizer=tokenizer,
-            max_seq_length=args.max_seq_length,
-            use_kv_cache=args.use_kv_cache,  # pyre-ignore
-            enable_dynamic_shape=args.enable_dynamic_shape,  # pyre-ignore
+            max_seq_length=config.sequence.max_seq_length,
+            use_kv_cache=config.kv_cache.use_kv_cache,
+            enable_dynamic_shape=config.misc.enable_dynamic_shape,
         )
     else:
         # TODO: use manager.pre_autograd_graph_module for the eval to remove the if-else branch
@@ -228,18 +221,94 @@ def gen_eval_wrapper(
         # that is not available in this eval_llama. We save the checkpoint
         # here for consistency with eval_llama. The accuracy results we
         # get from eval_llama can be used as a reference to other evaluations.
-        if args.output_eager_checkpoint_file is not None:  # pyre-ignore
-            torch.save(model, args.output_eager_checkpoint_file)
+        if config.eval.output_eager_checkpoint_file is not None:
+            torch.save(model, config.eval.output_eager_checkpoint_file)
 
         return EagerEvalWrapper(
             model=model,
             tokenizer=tokenizer,
-            max_seq_length=args.max_seq_length,
-            use_kv_cache=args.use_kv_cache,
+            max_seq_length=config.sequence.max_seq_length,
+            use_kv_cache=config.kv_cache.use_kv_cache,
+        )
+
+
+def eval_llama(
+    model_name: str,
+    config: DictConfig,
+) -> None:
+    # Generate the eval wrapper
+    eval_wrapper = gen_eval_wrapper(model_name, config)
+
+    # Needed for loading mmlu dataset.
+    # See https://github.com/EleutherAI/lm-evaluation-harness/pull/1998/files
+    if config.eval.tasks and "mmlu" in config.eval.tasks:
+        import datasets
+
+        datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = True
+
+    # Evaluate the model
+    tasks = (
+        None if config.eval.tasks is None else OmegaConf.to_container(config.eval.tasks)
+    )
+    with torch.no_grad():
+        eval_results = simple_evaluate(
+            model=eval_wrapper,
+            tasks=tasks,
+            num_fewshot=config.eval.num_fewshot,
+            limit=config.eval.limit,
         )
 
+    for task, res in eval_results["results"].items():
+        print(f"{task}: {res}")
+
+
+def eval_llama_with_attention_sink(
+    model_name: str,
+    config: DictConfig,
+) -> None:
+    # Generate the eval wrapper
+    eval_wrapper = gen_eval_wrapper(model_name, config)
+
+    # Needed for loading mmlu dataset.
+    # See https://github.com/EleutherAI/lm-evaluation-harness/pull/1998/files
+    if config.eval.tasks and "mmlu" in config.eval.tasks:
+        import datasets
+
+        datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = True
+
+    # Evaluate the model
+    with torch.no_grad():
+        eval_results = simple_evaluate(
+            model=eval_wrapper,
+            tasks=OmegaConf.to_container(config.eval.tasks),
+            num_fewshot=config.eval.num_fewshot,
+            limit=config.eval.limit,
+        )
+
+    for task, res in eval_results["results"].items():
+        print(f"{task}: {res}")
+
+
+def _convert_cli_to_config_format(args) -> DictConfig:
+    """Convert CLI arguments to config format."""
+    # First convert common args using the shared function
+    config = _convert_args_to_config(args)
+
+    # Add evaluation-specific settings
+    config.eval = OmegaConf.create()
+    config.eval.tasks = args.tasks
+    config.eval.limit = args.limit
+    config.eval.num_fewshot = args.num_fewshot
+    config.eval.pte = args.pte
+    config.eval.tokenizer_bin = args.tokenizer_bin
+    config.eval.output_eager_checkpoint_file = args.output_eager_checkpoint_file
+    config.eval.attention_sink_eval_tokens = args.attention_sink_eval_tokens
+
+    return config
+
 
 def build_args_parser() -> argparse.ArgumentParser:
+    """Build argument parser for evaluation, extending the export parser with eval-specific args."""
     # Start with arg parser from export_llama_lib
     parser = _build_args_parser()
 
@@ -286,92 +355,7 @@ def build_args_parser() -> argparse.ArgumentParser:
         help="Save the checkpoint after source transformations, for other evaluation platform to run the same checkpoint.",
     )
 
-    # Set of parameters secpific to AttentionSink.
+    # Set of parameters specific to AttentionSink.
     parser.add_argument("--attention_sink_eval_tokens", type=int, default=0)
 
     return parser
-
-
-def eval_llama(
-    model_name: str,
-    args: argparse.ArgumentParser,
-) -> None:
-    # Generate the eval wrapper
-    eval_wrapper = gen_eval_wrapper(model_name, args)
-
-    # Needed for loading mmlu dataset.
-    # See https://github.com/EleutherAI/lm-evaluation-harness/pull/1998/files
-    # pyre-ignore: Undefined attribute [16]: `argparse.ArgumentParser` has no attribute `tasks`
-    if args.tasks and "mmlu" in args.tasks:
-        import datasets
-
-        datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = True
-
-    # Evaluate the model
-    with torch.no_grad():
-        eval_results = simple_evaluate(
-            model=eval_wrapper,
-            tasks=args.tasks,
-            num_fewshot=args.num_fewshot,  # pyre-ignore: Undefined attribute [16]: `argparse.ArgumentParser` has no attribute `num_fewshot`
-            limit=args.limit,  # pyre-ignore: Undefined attribute [16]: `argparse.ArgumentParser` has no attribute `limit`
-        )
-
-    for task, res in eval_results["results"].items():
-        print(f"{task}: {res}")
-
-
-def eval_llama_with_attention_sink(model_name: str, args: argparse.ArgumentParser):
-    """
-    Evaluate the model's perplexity when AttentionSink is enabled.
-
-    This is mostly copied from https://github.com/mit-han-lab/streaming-llm/blob/main/examples/eval_long_ppl.py
-    """
-    assert args.use_attention_sink is not None  # pyre-ignore [16]
-    assert args.attention_sink_eval_tokens > 0  # pyre-ignore [16]
-    attention_sink_params = args.use_attention_sink.split(",")
-    assert len(attention_sink_params) == 3
-    sink_size = int(attention_sink_params[0])
-    window_size = int(attention_sink_params[1])
-
-    assert args.max_seq_length == sink_size + window_size  # pyre-ignore [16]
-
-    device = "cuda" if torch.cuda.is_available() else "cpu"
-    manager: LLMEdgeManager = _prepare_for_llama_export(args)
-    model = manager.model.eval().to(device=device)
-    tokenizer = get_tokenizer(args.tokenizer_path)  # pyre-ignore [16]
-
-    eval_data = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
-
-    nlls = []
-    loss_fn = CrossEntropyLoss(reduction="none")
-    progress_bar = tqdm(total=args.attention_sink_eval_tokens)
-    input_pos = 0
-    while input_pos < args.attention_sink_eval_tokens:
-        for text in eval_data["text"]:  # pyre-ignore [16]
-            tokens = tokenizer.encode(text, bos=False, eos=False)
-            if len(tokens) <= 0:
-                continue
-            with torch.no_grad():
-                num_tokens = min(
-                    len(tokens) - 1, args.attention_sink_eval_tokens - input_pos
-                )
-                logits = model(
-                    torch.tensor(
-                        [tokens[:num_tokens]], dtype=torch.int64, device=device
-                    ),
-                    torch.tensor([input_pos], dtype=torch.int64, device=device),
-                ).squeeze(dim=0)
-                neg_log_likelihood = loss_fn(
-                    logits,
-                    torch.tensor(
-                        [tokens[1 : num_tokens + 1]], dtype=torch.int64, device=device
-                    ).view(-1),
-                )
-                nlls.append(neg_log_likelihood)
-                input_pos += num_tokens
-                progress_bar.update(num_tokens)
-            if input_pos >= args.attention_sink_eval_tokens:
-                break
-    ppl = torch.exp(torch.cat(nlls).mean())
-    print(f"Perplexity: {ppl.item()}")
-    return ppl.item()
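
Taken together, the new functions are meant to be driven from a thin CLI entry point: build the extended parser, parse the arguments, convert the resulting argparse.Namespace into a DictConfig with _convert_cli_to_config_format, and hand the config to eval_llama (or eval_llama_with_attention_sink). A minimal sketch of that wiring, assuming this module's import path; the driver script and the model name are illustrative and not part of this diff:

# Hypothetical driver script; only build_args_parser, _convert_cli_to_config_format,
# and eval_llama come from the module changed in this diff.
from executorch.examples.models.llama.eval_llama_lib import (
    _convert_cli_to_config_format,
    build_args_parser,
    eval_llama,
)


def main() -> None:
    parser = build_args_parser()  # export parser plus eval-specific flags
    args = parser.parse_args()
    config = _convert_cli_to_config_format(args)  # argparse.Namespace -> DictConfig
    eval_llama("llama2", config)  # model name is illustrative


if __name__ == "__main__":
    main()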