
Commit caf3f18

Martin Yuan authored and facebook-github-bot committed
Convert Args to DictConfig (#9717)
Summary:
1. Add structure to the complicated args
2. Convert args to DictConfig, to decouple the cli args
3. Pass needed sub configs to functions (instead of args)

Test Plan:
```
python3 -m examples.models.llama.export_llama -c stories110M.pt -p params.json -d fp32 -n tinyllama_xnnpack+custom_fp32_main.pte -kv -X --xnnpack-extended-ops -qmode 8da4w -G 128 --use_sdpa_with_kv_cache --output-dir tmp -E 8,64
```

Reviewed By: larryliu0820

Differential Revision: D71557301

Pulled By: jackzhxng
1 parent d72ef5b commit caf3f18
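
A minimal sketch (not part of this commit) of the pattern the summary describes: regrouping the argparse namespace into a nested DictConfig so downstream functions take only the sub-config they need. The field grouping mirrors the config.model / config.sequence / config.kv_cache names that appear in the diff below; the helper name args_to_config is hypothetical, unlike the real _convert_args_to_config in export_llama_lib.py.

```
# Hypothetical sketch, not the commit's code: regroup CLI args into a nested
# DictConfig and pass sub-configs instead of the whole namespace.
import argparse

from omegaconf import DictConfig, OmegaConf


def args_to_config(args: argparse.Namespace) -> DictConfig:
    # Grouping mirrors the sub-configs used in the diff below.
    return OmegaConf.create(
        {
            "model": {"tokenizer_path": args.tokenizer_path},
            "sequence": {"max_seq_length": args.max_seq_length},
            "kv_cache": {"use_kv_cache": args.use_kv_cache},
        }
    )


def report_sequence_settings(sequence_config: DictConfig) -> None:
    # Downstream code depends only on the sub-config it needs,
    # which is the decoupling the commit summary is after.
    print(f"max_seq_length={sequence_config.max_seq_length}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--tokenizer_path", default="tokenizer.model")
    parser.add_argument("--max_seq_length", type=int, default=128)
    parser.add_argument("--use_kv_cache", action="store_true")
    config = args_to_config(parser.parse_args())
    report_sequence_settings(config.sequence)
```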

File tree

7 files changed: +454 -2‍74 lines changed

examples/models/llama/eval_llama.py (+4 -2)
@@ -11,6 +11,7 @@
 import torch
 
 from .eval_llama_lib import (
+    _convert_cli_to_config_format,
     build_args_parser,
     eval_llama,
     eval_llama_with_attention_sink,
@@ -28,10 +29,11 @@ def main() -> None:
     args = parser.parse_args()
     # Overrides this arg, because evaluation requires full logits.
     args.generate_full_logits = True
+    config = _convert_cli_to_config_format(args)
     if args.use_attention_sink:
-        eval_llama_with_attention_sink(modelname, args)  # pyre-ignore
+        eval_llama_with_attention_sink(modelname, config)  # pyre-ignore
     else:
-        eval_llama(modelname, args)  # pyre-ignore
+        eval_llama(modelname, config)  # pyre-ignore
 
 
 if __name__ == "__main__":
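
The eval_llama_lib.py diff below adds _convert_cli_to_config_format, which extends the shared config with an eval sub-config. A small standalone sketch of that extension pattern, with a made-up base config standing in for the real output of _convert_args_to_config:

```
# Standalone sketch of the pattern used by _convert_cli_to_config_format below;
# the base config here is a stand-in for the output of _convert_args_to_config.
from omegaconf import OmegaConf

base = OmegaConf.create({"sequence": {"max_seq_length": 128}})

# Configs built with OmegaConf.create are not in struct mode, so new keys
# (like a whole eval sub-config) can be attached after creation.
base.eval = OmegaConf.create()
base.eval.tasks = ["wikitext"]
base.eval.limit = None

print(OmegaConf.to_yaml(base))
```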

examples/models/llama/eval_llama_lib.py (+96 -113)
@@ -4,35 +4,27 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-
 import argparse
-
 from typing import Optional, Union
 
 import torch
-
-from datasets import load_dataset
 from executorch.examples.models.llama.export_llama_lib import (
+    _convert_args_to_config,
+    _prepare_for_llama_export,
+    build_args_parser as _build_args_parser,
     get_quantizer_and_quant_params,
 )
 from executorch.examples.models.llama.tokenizer.tiktoken import Tokenizer as Tiktoken
-
 from executorch.extension.llm.export.builder import LLMEdgeManager
 from executorch.extension.llm.tokenizer.tokenizer import (
     Tokenizer as SentencePieceTokenizer,
 )
 from executorch.extension.llm.tokenizer.utils import get_tokenizer
 from lm_eval.evaluator import simple_evaluate
-from torch.nn import CrossEntropyLoss
-from tqdm import tqdm
+from omegaconf import DictConfig, OmegaConf
 
 from .evaluate.eager_eval import EagerEvalWrapper
 
-from .export_llama_lib import (
-    _prepare_for_llama_export,
-    build_args_parser as _build_args_parser,
-)
-
 
 class GraphModuleEvalWrapper(EagerEvalWrapper):
     """
@@ -165,7 +157,7 @@ def _model_call(self, inps):
 
 def gen_eval_wrapper(
     model_name: str,
-    args: argparse.ArgumentParser,
+    config: DictConfig,
 ):
     """
     Generates a wrapper interface around the provided model and tokenizer for
@@ -174,17 +166,17 @@ def gen_eval_wrapper(
     Returns:
         eval_wrapper (LM): A wrapper interface for the lm-evaluation-harness library.
     """
-    tokenizer = get_tokenizer(args.tokenizer_path)  # pyre-ignore
+    tokenizer = get_tokenizer(config.model.tokenizer_path)
 
     # ExecuTorch Binary Evaluation
-    if (model := args.pte) is not None:  # pyre-ignore
-        if (tokenizer_bin := args.tokenizer_bin) is not None:  # pyre-ignore
+    if (model := config.eval.pte) is not None:
+        if (tokenizer_bin := config.eval.tokenizer_bin) is not None:
             # ETRunnerEvalWrapper: Create a wrapper around an ExecuTorch model, evaluated at runtime
             return ETRunnerEvalWrapper(
                 model=model,
                 tokenizer=tokenizer,
                 tokenizer_bin=tokenizer_bin,
-                max_seq_length=args.max_seq_length,  # pyre-ignore
+                max_seq_length=config.sequence.max_seq_length,
             )
 
         # ETPybindEvalWrapper: Create a wrapper around an ExecuTorch model, evaluated with pybindings
@@ -193,12 +185,12 @@ def gen_eval_wrapper(
             tokenizer=tokenizer,
             # Exported model takes at most (max_seq_length - 1) tokens.
             # Note that the eager model takes at most max_seq_length tokens.
-            max_seq_length=args.max_seq_length - 1,
+            max_seq_length=config.sequence.max_seq_length - 1,
         )
 
-    pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args)
+    pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(config)
     # GPTFastEvalWrapper: Create a wrapper around a pre-exported model
-    manager: LLMEdgeManager = _prepare_for_llama_export(args)
+    manager: LLMEdgeManager = _prepare_for_llama_export(config)
 
     if len(quantizers) != 0:
         manager = manager.export().pt2e_quantize(quantizers)
@@ -210,9 +202,9 @@ def gen_eval_wrapper(
         return GraphModuleEvalWrapper(
             model=model,
             tokenizer=tokenizer,
-            max_seq_length=args.max_seq_length,
-            use_kv_cache=args.use_kv_cache,  # pyre-ignore
-            enable_dynamic_shape=args.enable_dynamic_shape,  # pyre-ignore
+            max_seq_length=config.sequence.max_seq_length,
+            use_kv_cache=config.kv_cache.use_kv_cache,  # pyre-ignore
+            enable_dynamic_shape=config.misc.enable_dynamic_shape,  # pyre-ignore
         )
     else:
         # TODO: use manager.pre_autograd_graph_module for the eval to remove the if-else branch
@@ -230,18 +222,94 @@ def gen_eval_wrapper(
         # that is not available in this eval_llama. We save the checkpoint
         # here for consistency with eval_llama. The accuracy results we
         # get from eval_llama can be used as a reference to other evaluations.
-        if args.output_eager_checkpoint_file is not None:  # pyre-ignore
-            torch.save(model, args.output_eager_checkpoint_file)
+        if config.eval.output_eager_checkpoint_file is not None:  # pyre-ignore
+            torch.save(model, config.eval.output_eager_checkpoint_file)
 
         return EagerEvalWrapper(
             model=model,
             tokenizer=tokenizer,
-            max_seq_length=args.max_seq_length,
-            use_kv_cache=args.use_kv_cache,
+            max_seq_length=config.sequence.max_seq_length,
+            use_kv_cache=config.kv_cache.use_kv_cache,
+        )
+
+
+def eval_llama(
+    model_name: str,
+    config: DictConfig,
+) -> None:
+    # Generate the eval wrapper
+    eval_wrapper = gen_eval_wrapper(model_name, config)
+
+    # Needed for loading mmlu dataset.
+    # See https://github.com/EleutherAI/lm-evaluation-harness/pull/1998/files
+    if config.eval.tasks and "mmlu" in config.eval.tasks:
+        import datasets
+
+        datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = True
+
+    # Evaluate the model
+    tasks = (
+        None if config.eval.tasks is None else OmegaConf.to_container(config.eval.tasks)
+    )
+    with torch.no_grad():
+        eval_results = simple_evaluate(
+            model=eval_wrapper,
+            tasks=tasks,
+            num_fewshot=config.eval.num_fewshot,
+            limit=config.eval.limit,
+        )
+
+    for task, res in eval_results["results"].items():
+        print(f"{task}: {res}")
+
+
+def eval_llama_with_attention_sink(
+    model_name: str,
+    config: DictConfig,
+) -> None:
+    # Generate the eval wrapper
+    eval_wrapper = gen_eval_wrapper(model_name, config)
+
+    # Needed for loading mmlu dataset.
+    # See https://github.com/EleutherAI/lm-evaluation-harness/pull/1998/files
+    if config.eval.tasks and "mmlu" in config.eval.tasks:
+        import datasets
+
+        datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = True
+
+    # Evaluate the model
+    with torch.no_grad():
+        eval_results = simple_evaluate(
+            model=eval_wrapper,
+            tasks=OmegaConf.to_container(config.eval.tasks),
+            num_fewshot=config.eval.num_fewshot,
+            limit=config.eval.limit,
         )
 
+    for task, res in eval_results["results"].items():
+        print(f"{task}: {res}")
+
+
+def _convert_cli_to_config_format(args) -> DictConfig:
+    """Convert CLI arguments to config format."""
+    # First convert common args using the shared function
+    config = _convert_args_to_config(args)
+
+    # Add evaluation-specific settings
+    config.eval = OmegaConf.create()
+    config.eval.tasks = args.tasks
+    config.eval.limit = args.limit
+    config.eval.num_fewshot = args.num_fewshot
+    config.eval.pte = args.pte
+    config.eval.tokenizer_bin = args.tokenizer_bin
+    config.eval.output_eager_checkpoint_file = args.output_eager_checkpoint_file
+    config.eval.attention_sink_eval_tokens = args.attention_sink_eval_tokens
+
+    return config
+
 
 def build_args_parser() -> argparse.ArgumentParser:
+    """Build argument parser for evaluation, extending the export parser with eval-specific args."""
     # Start with arg parser from export_llama_lib
     parser = _build_args_parser()
 
@@ -288,92 +356,7 @@ def build_args_parser() -> argparse.ArgumentParser:
         help="Save the checkpoint after source transformations, for other evaluation platform to run the same checkpoint.",
     )
 
-    # Set of parameters secpific to AttentionSink.
+    # Set of parameters specific to AttentionSink.
     parser.add_argument("--attention_sink_eval_tokens", type=int, default=0)
 
     return parser
-
-
-def eval_llama(
-    model_name: str,
-    args: argparse.ArgumentParser,
-) -> None:
-    # Generate the eval wrapper
-    eval_wrapper = gen_eval_wrapper(model_name, args)
-
-    # Needed for loading mmlu dataset.
-    # See https://github.com/EleutherAI/lm-evaluation-harness/pull/1998/files
-    # pyre-ignore: Undefined attribute [16]: `argparse.ArgumentParser` has no attribute `tasks`
-    if args.tasks and "mmlu" in args.tasks:
-        import datasets
-
-        datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = True
-
-    # Evaluate the model
-    with torch.no_grad():
-        eval_results = simple_evaluate(
-            model=eval_wrapper,
-            tasks=args.tasks,
-            num_fewshot=args.num_fewshot,  # pyre-ignore: Undefined attribute [16]: `argparse.ArgumentParser` has no attribute `num_fewshot`
-            limit=args.limit,  # pyre-ignore: Undefined attribute [16]: `argparse.ArgumentParser` has no attribute `limit`
-        )
-
-    for task, res in eval_results["results"].items():
-        print(f"{task}: {res}")
-
-
-def eval_llama_with_attention_sink(model_name: str, args: argparse.ArgumentParser):
-    """
-    Evaluate the model's perplexity when AttentionSink is enabled.
-
-    This is mostly copied from https://github.com/mit-han-lab/streaming-llm/blob/main/examples/eval_long_ppl.py
-    """
-    assert args.use_attention_sink is not None  # pyre-ignore [16]
-    assert args.attention_sink_eval_tokens > 0  # pyre-ignore [16]
-    attention_sink_params = args.use_attention_sink.split(",")
-    assert len(attention_sink_params) == 3
-    sink_size = int(attention_sink_params[0])
-    window_size = int(attention_sink_params[1])
-
-    assert args.max_seq_length == sink_size + window_size  # pyre-ignore [16]
-
-    device = "cuda" if torch.cuda.is_available() else "cpu"
-    manager: LLMEdgeManager = _prepare_for_llama_export(args)
-    model = manager.model.eval().to(device=device)
-    tokenizer = get_tokenizer(args.tokenizer_path)  # pyre-ignore [16]
-
-    eval_data = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
-
-    nlls = []
-    loss_fn = CrossEntropyLoss(reduction="none")
-    progress_bar = tqdm(total=args.attention_sink_eval_tokens)
-    input_pos = 0
-    while input_pos < args.attention_sink_eval_tokens:
-        for text in eval_data["text"]:  # pyre-ignore [16]
-            tokens = tokenizer.encode(text, bos=False, eos=False)
-            if len(tokens) <= 0:
-                continue
-            with torch.no_grad():
-                num_tokens = min(
-                    len(tokens) - 1, args.attention_sink_eval_tokens - input_pos
-                )
-                logits = model(
-                    torch.tensor(
-                        [tokens[:num_tokens]], dtype=torch.int64, device=device
-                    ),
-                    torch.tensor([input_pos], dtype=torch.int64, device=device),
-                ).squeeze(dim=0)
-                neg_log_likelihood = loss_fn(
-                    logits,
-                    torch.tensor(
-                        [tokens[1 : num_tokens + 1]], dtype=torch.int64, device=device
-                    ).view(-1),
-                )
-                nlls.append(neg_log_likelihood)
-                input_pos += num_tokens
-                progress_bar.update(num_tokens)
-                if input_pos >= args.attention_sink_eval_tokens:
-                    break
-    ppl = torch.exp(torch.cat(nlls).mean())
-    print(f"Perplexity: {ppl.item()}")
-    return ppl.item()
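
One detail worth noting in the new eval_llama / eval_llama_with_attention_sink: tasks read back from the DictConfig come out as an OmegaConf ListConfig, so the code converts them with OmegaConf.to_container before calling simple_evaluate. A tiny standalone sketch of that conversion (task names made up):

```
# Sketch of why the new code calls OmegaConf.to_container on config.eval.tasks:
# list values read from a DictConfig are ListConfig objects, while callers
# such as lm-eval generally expect plain Python lists.
from omegaconf import OmegaConf

config = OmegaConf.create({"eval": {"tasks": ["wikitext", "mmlu"]}})

tasks = (
    None if config.eval.tasks is None else OmegaConf.to_container(config.eval.tasks)
)
print(type(config.eval.tasks).__name__)  # ListConfig
print(type(tasks).__name__, tasks)       # list ['wikitext', 'mmlu']
```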
