Commit 1e32696

Martin Yuan authored and facebook-github-bot committed
Convert Args to DictConfig (#9717)
Summary:
Pull Request resolved: #9717

1. Add structure to the complicated args
2. Convert args to DictConfig, to decouple the CLI args
3. Pass needed sub configs to functions (instead of args)

Pull Request resolved: #9450

Test Plan:
```
python3 -m examples.models.llama.export_llama -c stories110M.pt -p params.json -d fp32 -n tinyllama_xnnpack+custom_fp32_main.pte -kv -X --xnnpack-extended-ops -qmode 8da4w -G 128 --use_sdpa_with_kv_cache --output-dir tmp -E 8,64
```

Reviewed By: larryliu0820

Differential Revision: D71557301

Pulled By: jackzhxng
1 parent d72ef5b commit 1e32696
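For context, a minimal sketch of the pattern this change adopts (the helper and function names below are illustrative, not the ones added by this PR): the flat argparse Namespace is converted once into a nested OmegaConf DictConfig, and each callee receives only the sub-config it needs instead of the whole args object.

```
from argparse import Namespace

from omegaconf import DictConfig, OmegaConf


def namespace_to_config(args: Namespace) -> DictConfig:
    # Hypothetical stand-in for the PR's conversion helpers: group the flat
    # CLI args into the nested sections that downstream code reads.
    return OmegaConf.create(
        {
            "model": {"tokenizer_path": args.tokenizer_path},
            "sequence": {"max_seq_length": args.max_seq_length},
            "kv_cache": {"use_kv_cache": args.use_kv_cache},
        }
    )


def build_kv_cache(kv_cache_cfg: DictConfig) -> None:
    # A callee now depends only on the sub-config it needs, not on the CLI surface.
    if kv_cache_cfg.use_kv_cache:
        print("allocating KV cache")


args = Namespace(tokenizer_path="tokenizer.model", max_seq_length=128, use_kv_cache=True)
config = namespace_to_config(args)
build_kv_cache(config.kv_cache)  # pass the needed sub-config instead of `args`
```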

File tree: 7 files changed (+454, -274 lines)


examples/models/llama/eval_llama.py (+4, -2)

@@ -11,6 +11,7 @@
 import torch
 
 from .eval_llama_lib import (
+    _convert_cli_to_config_format,
     build_args_parser,
     eval_llama,
     eval_llama_with_attention_sink,
@@ -28,10 +29,11 @@ def main() -> None:
     args = parser.parse_args()
     # Overrides this arg, because evaluation requires full logits.
     args.generate_full_logits = True
+    config = _convert_cli_to_config_format(args)
     if args.use_attention_sink:
-        eval_llama_with_attention_sink(modelname, args)  # pyre-ignore
+        eval_llama_with_attention_sink(modelname, config)  # pyre-ignore
     else:
-        eval_llama(modelname, args)  # pyre-ignore
+        eval_llama(modelname, config)  # pyre-ignore
 
 
 if __name__ == "__main__":
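The call sites above now hand an OmegaConf DictConfig to the eval entry points instead of the raw Namespace. As background (not part of this commit), here is a short sketch of the DictConfig behavior the eval code relies on; the keys are illustrative and the real object comes from _convert_cli_to_config_format(args).

```
from omegaconf import DictConfig, OmegaConf

# Illustrative keys only.
config: DictConfig = OmegaConf.create({"eval": {"tasks": ["wikitext"], "limit": None}})

# Nested values read like attributes, which is how the eval code accesses them.
assert config.eval.limit is None
print(config.eval.tasks)  # a ListConfig; prints as ['wikitext']

# Third-party APIs such as lm-eval expect plain Python containers,
# hence the OmegaConf.to_container calls in eval_llama below.
tasks = OmegaConf.to_container(config.eval.tasks)
assert isinstance(tasks, list) and tasks == ["wikitext"]
```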

examples/models/llama/eval_llama_lib.py (+96, -113)

@@ -4,35 +4,27 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-
 import argparse
-
 from typing import Optional, Union
 
 import torch
-
-from datasets import load_dataset
 from executorch.examples.models.llama.export_llama_lib import (
+    _convert_args_to_config,
+    _prepare_for_llama_export,
+    build_args_parser as _build_args_parser,
     get_quantizer_and_quant_params,
 )
 from executorch.examples.models.llama.tokenizer.tiktoken import Tokenizer as Tiktoken
-
 from executorch.extension.llm.export.builder import LLMEdgeManager
 from executorch.extension.llm.tokenizer.tokenizer import (
     Tokenizer as SentencePieceTokenizer,
 )
 from executorch.extension.llm.tokenizer.utils import get_tokenizer
 from lm_eval.evaluator import simple_evaluate
-from torch.nn import CrossEntropyLoss
-from tqdm import tqdm
+from omegaconf import DictConfig, OmegaConf
 
 from .evaluate.eager_eval import EagerEvalWrapper
 
-from .export_llama_lib import (
-    _prepare_for_llama_export,
-    build_args_parser as _build_args_parser,
-)
-
 
 class GraphModuleEvalWrapper(EagerEvalWrapper):
     """
@@ -165,7 +157,7 @@ def _model_call(self, inps):
 
 def gen_eval_wrapper(
     model_name: str,
-    args: argparse.ArgumentParser,
+    config: DictConfig,
 ):
     """
     Generates a wrapper interface around the provided model and tokenizer for
@@ -174,17 +166,17 @@ def gen_eval_wrapper(
     Returns:
         eval_wrapper (LM): A wrapper interface for the lm-evaluation-harness library.
     """
-    tokenizer = get_tokenizer(args.tokenizer_path)  # pyre-ignore
+    tokenizer = get_tokenizer(config.model.tokenizer_path)
 
     # ExecuTorch Binary Evaluation
-    if (model := args.pte) is not None:  # pyre-ignore
-        if (tokenizer_bin := args.tokenizer_bin) is not None:  # pyre-ignore
+    if (model := config.eval.pte) is not None:
+        if (tokenizer_bin := config.eval.tokenizer_bin) is not None:
             # ETRunnerEvalWrapper: Create a wrapper around an ExecuTorch model, evaluated at runtime
             return ETRunnerEvalWrapper(
                 model=model,
                 tokenizer=tokenizer,
                 tokenizer_bin=tokenizer_bin,
-                max_seq_length=args.max_seq_length,  # pyre-ignore
+                max_seq_length=config.sequence.max_seq_length,
             )
 
     # ETPybindEvalWrapper: Create a wrapper around an ExecuTorch model, evaluated with pybindings
@@ -193,12 +185,12 @@ def gen_eval_wrapper(
             tokenizer=tokenizer,
             # Exported model takes at most (max_seq_length - 1) tokens.
             # Note that the eager model takes at most max_seq_length tokens.
-            max_seq_length=args.max_seq_length - 1,
+            max_seq_length=config.sequence.max_seq_length - 1,
         )
 
-    pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args)
+    pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(config)
     # GPTFastEvalWrapper: Create a wrapper around a pre-exported model
-    manager: LLMEdgeManager = _prepare_for_llama_export(args)
+    manager: LLMEdgeManager = _prepare_for_llama_export(config)
 
     if len(quantizers) != 0:
         manager = manager.export().pt2e_quantize(quantizers)
@@ -210,9 +202,9 @@ def gen_eval_wrapper(
         return GraphModuleEvalWrapper(
             model=model,
             tokenizer=tokenizer,
-            max_seq_length=args.max_seq_length,
-            use_kv_cache=args.use_kv_cache,  # pyre-ignore
-            enable_dynamic_shape=args.enable_dynamic_shape,  # pyre-ignore
+            max_seq_length=config.sequence.max_seq_length,
+            use_kv_cache=config.kv_cache.use_kv_cache,  # pyre-ignore
+            enable_dynamic_shape=config.misc.enable_dynamic_shape,  # pyre-ignore
         )
     else:
         # TODO: use manager.pre_autograd_graph_module for the eval to remove the if-else branch
@@ -230,18 +222,94 @@ def gen_eval_wrapper(
         # that is not available in this eval_llama. We save the checkpoint
         # here for consistency with eval_llama. The accuracy results we
         # get from eval_llama can be used as a reference to other evaluations.
-        if args.output_eager_checkpoint_file is not None:  # pyre-ignore
-            torch.save(model, args.output_eager_checkpoint_file)
+        if config.eval.output_eager_checkpoint_file is not None:  # pyre-ignore
+            torch.save(model, config.eval.output_eager_checkpoint_file)
 
         return EagerEvalWrapper(
             model=model,
             tokenizer=tokenizer,
-            max_seq_length=args.max_seq_length,
-            use_kv_cache=args.use_kv_cache,
+            max_seq_length=config.sequence.max_seq_length,
+            use_kv_cache=config.kv_cache.use_kv_cache,
+        )
+
+
+def eval_llama(
+    model_name: str,
+    config: DictConfig,
+) -> None:
+    # Generate the eval wrapper
+    eval_wrapper = gen_eval_wrapper(model_name, config)
+
+    # Needed for loading mmlu dataset.
+    # See https://github.com/EleutherAI/lm-evaluation-harness/pull/1998/files
+    if config.eval.tasks and "mmlu" in config.eval.tasks:
+        import datasets
+
+        datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = True
+
+    # Evaluate the model
+    tasks = (
+        None if config.eval.tasks is None else OmegaConf.to_container(config.eval.tasks)
+    )
+    with torch.no_grad():
+        eval_results = simple_evaluate(
+            model=eval_wrapper,
+            tasks=tasks,
+            num_fewshot=config.eval.num_fewshot,
+            limit=config.eval.limit,
+        )
+
+    for task, res in eval_results["results"].items():
+        print(f"{task}: {res}")
+
+
+def eval_llama_with_attention_sink(
+    model_name: str,
+    config: DictConfig,
+) -> None:
+    # Generate the eval wrapper
+    eval_wrapper = gen_eval_wrapper(model_name, config)
+
+    # Needed for loading mmlu dataset.
+    # See https://github.com/EleutherAI/lm-evaluation-harness/pull/1998/files
+    if config.eval.tasks and "mmlu" in config.eval.tasks:
+        import datasets
+
+        datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = True
+
+    # Evaluate the model
+    with torch.no_grad():
+        eval_results = simple_evaluate(
+            model=eval_wrapper,
+            tasks=OmegaConf.to_container(config.eval.tasks),
+            num_fewshot=config.eval.num_fewshot,
+            limit=config.eval.limit,
         )
 
+    for task, res in eval_results["results"].items():
+        print(f"{task}: {res}")
+
+
+def _convert_cli_to_config_format(args) -> DictConfig:
+    """Convert CLI arguments to config format."""
+    # First convert common args using the shared function
+    config = _convert_args_to_config(args)
+
+    # Add evaluation-specific settings
+    config.eval = OmegaConf.create()
+    config.eval.tasks = args.tasks
+    config.eval.limit = args.limit
+    config.eval.num_fewshot = args.num_fewshot
+    config.eval.pte = args.pte
+    config.eval.tokenizer_bin = args.tokenizer_bin
+    config.eval.output_eager_checkpoint_file = args.output_eager_checkpoint_file
+    config.eval.attention_sink_eval_tokens = args.attention_sink_eval_tokens
+
+    return config
+
 
 def build_args_parser() -> argparse.ArgumentParser:
+    """Build argument parser for evaluation, extending the export parser with eval-specific args."""
     # Start with arg parser from export_llama_lib
     parser = _build_args_parser()
 
@@ -288,92 +356,7 @@ def build_args_parser() -> argparse.ArgumentParser:
         help="Save the checkpoint after source transformations, for other evaluation platform to run the same checkpoint.",
     )
 
-    # Set of parameters secpific to AttentionSink.
+    # Set of parameters specific to AttentionSink.
     parser.add_argument("--attention_sink_eval_tokens", type=int, default=0)
 
     return parser
-
-
-def eval_llama(
-    model_name: str,
-    args: argparse.ArgumentParser,
-) -> None:
-    # Generate the eval wrapper
-    eval_wrapper = gen_eval_wrapper(model_name, args)
-
-    # Needed for loading mmlu dataset.
-    # See https://github.com/EleutherAI/lm-evaluation-harness/pull/1998/files
-    # pyre-ignore: Undefined attribute [16]: `argparse.ArgumentParser` has no attribute `tasks`
-    if args.tasks and "mmlu" in args.tasks:
-        import datasets
-
-        datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = True
-
-    # Evaluate the model
-    with torch.no_grad():
-        eval_results = simple_evaluate(
-            model=eval_wrapper,
-            tasks=args.tasks,
-            num_fewshot=args.num_fewshot,  # pyre-ignore: Undefined attribute [16]: `argparse.ArgumentParser` has no attribute `num_fewshot`
-            limit=args.limit,  # pyre-ignore: Undefined attribute [16]: `argparse.ArgumentParser` has no attribute `limit`
-        )
-
-    for task, res in eval_results["results"].items():
-        print(f"{task}: {res}")
-
-
-def eval_llama_with_attention_sink(model_name: str, args: argparse.ArgumentParser):
-    """
-    Evaluate the model's perplexity when AttentionSink is enabled.
-
-    This is mostly copied from https://github.com/mit-han-lab/streaming-llm/blob/main/examples/eval_long_ppl.py
-    """
-    assert args.use_attention_sink is not None  # pyre-ignore [16]
-    assert args.attention_sink_eval_tokens > 0  # pyre-ignore [16]
-    attention_sink_params = args.use_attention_sink.split(",")
-    assert len(attention_sink_params) == 3
-    sink_size = int(attention_sink_params[0])
-    window_size = int(attention_sink_params[1])
-
-    assert args.max_seq_length == sink_size + window_size  # pyre-ignore [16]
-
-    device = "cuda" if torch.cuda.is_available() else "cpu"
-    manager: LLMEdgeManager = _prepare_for_llama_export(args)
-    model = manager.model.eval().to(device=device)
-    tokenizer = get_tokenizer(args.tokenizer_path)  # pyre-ignore [16]
-
-    eval_data = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
-
-    nlls = []
-    loss_fn = CrossEntropyLoss(reduction="none")
-    progress_bar = tqdm(total=args.attention_sink_eval_tokens)
-    input_pos = 0
-    while input_pos < args.attention_sink_eval_tokens:
-        for text in eval_data["text"]:  # pyre-ignore [16]
-            tokens = tokenizer.encode(text, bos=False, eos=False)
-            if len(tokens) <= 0:
-                continue
-            with torch.no_grad():
-                num_tokens = min(
-                    len(tokens) - 1, args.attention_sink_eval_tokens - input_pos
-                )
-                logits = model(
-                    torch.tensor(
-                        [tokens[:num_tokens]], dtype=torch.int64, device=device
-                    ),
-                    torch.tensor([input_pos], dtype=torch.int64, device=device),
-                ).squeeze(dim=0)
-                neg_log_likelihood = loss_fn(
-                    logits,
-                    torch.tensor(
-                        [tokens[1 : num_tokens + 1]], dtype=torch.int64, device=device
-                    ).view(-1),
-                )
-                nlls.append(neg_log_likelihood)
-                input_pos += num_tokens
-                progress_bar.update(num_tokens)
-            if input_pos >= args.attention_sink_eval_tokens:
-                break
-    ppl = torch.exp(torch.cat(nlls).mean())
-    print(f"Perplexity: {ppl.item()}")
-    return ppl.item()
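Putting the pieces together, this is roughly the nested layout that gen_eval_wrapper reads in the diff above. The group and key names are the ones visible in this commit; the values and the direct OmegaConf.create call are only a sketch, since the real object is produced by _convert_cli_to_config_format / _convert_args_to_config rather than built by hand.

```
from omegaconf import OmegaConf

# Sketch of the assumed config layout; values are placeholders.
config = OmegaConf.create(
    {
        "model": {"tokenizer_path": "tokenizer.model"},
        "sequence": {"max_seq_length": 128},
        "kv_cache": {"use_kv_cache": True},
        "misc": {"enable_dynamic_shape": False},
        "eval": {
            "tasks": ["wikitext"],
            "limit": None,
            "num_fewshot": None,
            "pte": None,
            "tokenizer_bin": None,
            "output_eager_checkpoint_file": None,
            "attention_sink_eval_tokens": 0,
        },
    }
)

# gen_eval_wrapper(model_name, config) then reads values such as:
assert config.model.tokenizer_path == "tokenizer.model"
assert config.sequence.max_seq_length == 128
assert config.eval.pte is None  # no .pte given, so the eager/exported-module path is taken
```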
