
Commit 4054c53

Authored by jackzhxng, committed by facebook-github-bot

Convert Args to DictConfig (#9450) (#9717)

Summary:
1. Add structure to the complicated args.
2. Convert args to DictConfig, to decouple the CLI args.
3. Pass the needed sub-configs to functions (instead of args).

Test Plan:
```
python3 -m examples.models.llama.export_llama -c stories110M.pt -p params.json -d fp32 -n tinyllama_xnnpack+custom_fp32_main.pte -kv -X --xnnpack-extended-ops -qmode 8da4w -G 128 --use_sdpa_with_kv_cache --output-dir tmp -E 8,64
```

Pulled By: jackzhxng

Differential Revision: D71557301
1 parent eef0010 commit 4054c53
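The summary above describes the core pattern of the change: a flat argparse namespace is grouped into nested OmegaConf sub-configs so that downstream functions can take just the piece they need. Below is a minimal, illustrative sketch of that idea only; the group names and the `args_to_config` helper are hypothetical stand-ins, while the real conversion is `_convert_args_to_config` in export_llama_lib, which is not shown in this diff.

```python
# Illustrative sketch only -- not the actual _convert_args_to_config implementation.
import argparse

from omegaconf import DictConfig, OmegaConf


def args_to_config(args: argparse.Namespace) -> DictConfig:
    # Group the flat CLI namespace into nested sub-configs, so callees can
    # receive e.g. config.sequence instead of the whole argparse object.
    return OmegaConf.create(
        {
            "model": {"tokenizer_path": args.tokenizer_path},
            "sequence": {"max_seq_length": args.max_seq_length},
            "kv_cache": {"use_kv_cache": args.use_kv_cache},
        }
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--tokenizer_path", default="tokenizer.model")
    parser.add_argument("--max_seq_length", type=int, default=128)
    parser.add_argument("--use_kv_cache", action="store_true")
    config = args_to_config(parser.parse_args())
    print(OmegaConf.to_yaml(config))
```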

File tree: 11 files changed (+554, -320 lines)


Diff for: examples/models/llama/eval_llama.py (+4, -2)
@@ -11,6 +11,7 @@
 import torch
 
 from .eval_llama_lib import (
+    _convert_cli_to_config_format,
     build_args_parser,
     eval_llama,
     eval_llama_with_attention_sink,
@@ -28,10 +29,11 @@ def main() -> None:
     args = parser.parse_args()
     # Overrides this arg, because evaluation requires full logits.
     args.generate_full_logits = True
+    config = _convert_cli_to_config_format(args)
     if args.use_attention_sink:
-        eval_llama_with_attention_sink(modelname, args)  # pyre-ignore
+        eval_llama_with_attention_sink(modelname, config)
     else:
-        eval_llama(modelname, args)  # pyre-ignore
+        eval_llama(modelname, config)
 
 
 if __name__ == "__main__":

Diff for: examples/models/llama/eval_llama_lib.py (+96, -112)
@@ -4,33 +4,26 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-
 import argparse
-
 from typing import Optional, Union
 
 import torch
-
-from datasets import load_dataset
 from executorch.examples.models.llama.export_llama_lib import (
+    _convert_args_to_config,
+    _prepare_for_llama_export,
+    build_args_parser as _build_args_parser,
     get_quantizer_and_quant_params,
 )
 
 from executorch.extension.llm.export.builder import LLMEdgeManager
 from lm_eval.evaluator import simple_evaluate
+from omegaconf import DictConfig, OmegaConf
 from pytorch_tokenizers import get_tokenizer
 from pytorch_tokenizers.llama2c import Llama2cTokenizer as SentencePieceTokenizer
 from pytorch_tokenizers.tiktoken import TiktokenTokenizer as Tiktoken
-from torch.nn import CrossEntropyLoss
-from tqdm import tqdm
 
 from .evaluate.eager_eval import EagerEvalWrapper
 
-from .export_llama_lib import (
-    _prepare_for_llama_export,
-    build_args_parser as _build_args_parser,
-)
-
 
 class GraphModuleEvalWrapper(EagerEvalWrapper):
     """
@@ -163,7 +156,7 @@ def _model_call(self, inps):
 
 def gen_eval_wrapper(
     model_name: str,
-    args: argparse.ArgumentParser,
+    config: DictConfig,
 ):
     """
     Generates a wrapper interface around the provided model and tokenizer for
@@ -172,17 +165,17 @@ def gen_eval_wrapper(
     Returns:
         eval_wrapper (LM): A wrapper interface for the lm-evaluation-harness library.
     """
-    tokenizer = get_tokenizer(args.tokenizer_path)  # pyre-ignore
+    tokenizer = get_tokenizer(config.model.tokenizer_path)
 
     # ExecuTorch Binary Evaluation
-    if (model := args.pte) is not None:  # pyre-ignore
-        if (tokenizer_bin := args.tokenizer_bin) is not None:  # pyre-ignore
+    if (model := config.eval.pte) is not None:
+        if (tokenizer_bin := config.eval.tokenizer_bin) is not None:
             # ETRunnerEvalWrapper: Create a wrapper around an ExecuTorch model, evaluated at runtime
             return ETRunnerEvalWrapper(
                 model=model,
                 tokenizer=tokenizer,
                 tokenizer_bin=tokenizer_bin,
-                max_seq_length=args.max_seq_length,  # pyre-ignore
+                max_seq_length=config.sequence.max_seq_length,
             )
 
         # ETPybindEvalWrapper: Create a wrapper around an ExecuTorch model, evaluated with pybindings
@@ -191,12 +184,12 @@ def gen_eval_wrapper(
             tokenizer=tokenizer,
             # Exported model takes at most (max_seq_length - 1) tokens.
             # Note that the eager model takes at most max_seq_length tokens.
-            max_seq_length=args.max_seq_length - 1,
+            max_seq_length=config.sequence.max_seq_length - 1,
         )
 
-    pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args)
+    pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(config)
     # GPTFastEvalWrapper: Create a wrapper around a pre-exported model
-    manager: LLMEdgeManager = _prepare_for_llama_export(args)
+    manager: LLMEdgeManager = _prepare_for_llama_export(config)
 
     if len(quantizers) != 0:
         manager = manager.export().pt2e_quantize(quantizers)
@@ -208,9 +201,9 @@ def gen_eval_wrapper(
         return GraphModuleEvalWrapper(
             model=model,
             tokenizer=tokenizer,
-            max_seq_length=args.max_seq_length,
-            use_kv_cache=args.use_kv_cache,  # pyre-ignore
-            enable_dynamic_shape=args.enable_dynamic_shape,  # pyre-ignore
+            max_seq_length=config.sequence.max_seq_length,
+            use_kv_cache=config.kv_cache.use_kv_cache,
+            enable_dynamic_shape=config.misc.enable_dynamic_shape,
         )
     else:
         # TODO: use manager.pre_autograd_graph_module for the eval to remove the if-else branch
@@ -228,18 +221,94 @@ def gen_eval_wrapper(
         # that is not available in this eval_llama. We save the checkpoint
         # here for consistency with eval_llama. The accuracy results we
         # get from eval_llama can be used as a reference to other evaluations.
-        if args.output_eager_checkpoint_file is not None:  # pyre-ignore
-            torch.save(model, args.output_eager_checkpoint_file)
+        if config.eval.output_eager_checkpoint_file is not None:
+            torch.save(model, config.eval.output_eager_checkpoint_file)
 
         return EagerEvalWrapper(
             model=model,
             tokenizer=tokenizer,
-            max_seq_length=args.max_seq_length,
-            use_kv_cache=args.use_kv_cache,
+            max_seq_length=config.sequence.max_seq_length,
+            use_kv_cache=config.kv_cache.use_kv_cache,
+        )
+
+
+def eval_llama(
+    model_name: str,
+    config: DictConfig,
+) -> None:
+    # Generate the eval wrapper
+    eval_wrapper = gen_eval_wrapper(model_name, config)
+
+    # Needed for loading mmlu dataset.
+    # See https://github.com/EleutherAI/lm-evaluation-harness/pull/1998/files
+    if config.eval.tasks and "mmlu" in config.eval.tasks:
+        import datasets
+
+        datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = True
+
+    # Evaluate the model
+    tasks = (
+        None if config.eval.tasks is None else OmegaConf.to_container(config.eval.tasks)
+    )
+    with torch.no_grad():
+        eval_results = simple_evaluate(
+            model=eval_wrapper,
+            tasks=tasks,
+            num_fewshot=config.eval.num_fewshot,
+            limit=config.eval.limit,
         )
 
+    for task, res in eval_results["results"].items():
+        print(f"{task}: {res}")
+
+
+def eval_llama_with_attention_sink(
+    model_name: str,
+    config: DictConfig,
+) -> None:
+    # Generate the eval wrapper
+    eval_wrapper = gen_eval_wrapper(model_name, config)
+
+    # Needed for loading mmlu dataset.
+    # See https://github.com/EleutherAI/lm-evaluation-harness/pull/1998/files
+    if config.eval.tasks and "mmlu" in config.eval.tasks:
+        import datasets
+
+        datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = True
+
+    # Evaluate the model
+    with torch.no_grad():
+        eval_results = simple_evaluate(
+            model=eval_wrapper,
+            tasks=OmegaConf.to_container(config.eval.tasks),
+            num_fewshot=config.eval.num_fewshot,
+            limit=config.eval.limit,
+        )
+
+    for task, res in eval_results["results"].items():
+        print(f"{task}: {res}")
+
+
+def _convert_cli_to_config_format(args) -> DictConfig:
+    """Convert CLI arguments to config format."""
+    # First convert common args using the shared function
+    config = _convert_args_to_config(args)
+
+    # Add evaluation-specific settings
+    config.eval = OmegaConf.create()
+    config.eval.tasks = args.tasks
+    config.eval.limit = args.limit
+    config.eval.num_fewshot = args.num_fewshot
+    config.eval.pte = args.pte
+    config.eval.tokenizer_bin = args.tokenizer_bin
+    config.eval.output_eager_checkpoint_file = args.output_eager_checkpoint_file
+    config.eval.attention_sink_eval_tokens = args.attention_sink_eval_tokens
+
+    return config
+
 
 def build_args_parser() -> argparse.ArgumentParser:
+    """Build argument parser for evaluation, extending the export parser with eval-specific args."""
     # Start with arg parser from export_llama_lib
     parser = _build_args_parser()
 
@@ -286,92 +355,7 @@ def build_args_parser() -> argparse.ArgumentParser:
         help="Save the checkpoint after source transformations, for other evaluation platform to run the same checkpoint.",
     )
 
-    # Set of parameters secpific to AttentionSink.
+    # Set of parameters specific to AttentionSink.
     parser.add_argument("--attention_sink_eval_tokens", type=int, default=0)
 
     return parser
-
-
-def eval_llama(
-    model_name: str,
-    args: argparse.ArgumentParser,
-) -> None:
-    # Generate the eval wrapper
-    eval_wrapper = gen_eval_wrapper(model_name, args)
-
-    # Needed for loading mmlu dataset.
-    # See https://github.com/EleutherAI/lm-evaluation-harness/pull/1998/files
-    # pyre-ignore: Undefined attribute [16]: `argparse.ArgumentParser` has no attribute `tasks`
-    if args.tasks and "mmlu" in args.tasks:
-        import datasets
-
-        datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = True
-
-    # Evaluate the model
-    with torch.no_grad():
-        eval_results = simple_evaluate(
-            model=eval_wrapper,
-            tasks=args.tasks,
-            num_fewshot=args.num_fewshot,  # pyre-ignore: Undefined attribute [16]: `argparse.ArgumentParser` has no attribute `num_fewshot`
-            limit=args.limit,  # pyre-ignore: Undefined attribute [16]: `argparse.ArgumentParser` has no attribute `limit`
-        )
-
-    for task, res in eval_results["results"].items():
-        print(f"{task}: {res}")
-
-
-def eval_llama_with_attention_sink(model_name: str, args: argparse.ArgumentParser):
-    """
-    Evaluate the model's perplexity when AttentionSink is enabled.
-
-    This is mostly copied from https://github.com/mit-han-lab/streaming-llm/blob/main/examples/eval_long_ppl.py
-    """
-    assert args.use_attention_sink is not None  # pyre-ignore [16]
-    assert args.attention_sink_eval_tokens > 0  # pyre-ignore [16]
-    attention_sink_params = args.use_attention_sink.split(",")
-    assert len(attention_sink_params) == 3
-    sink_size = int(attention_sink_params[0])
-    window_size = int(attention_sink_params[1])
-
-    assert args.max_seq_length == sink_size + window_size  # pyre-ignore [16]
-
-    device = "cuda" if torch.cuda.is_available() else "cpu"
-    manager: LLMEdgeManager = _prepare_for_llama_export(args)
-    model = manager.model.eval().to(device=device)
-    tokenizer = get_tokenizer(args.tokenizer_path)  # pyre-ignore [16]
-
-    eval_data = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
-
-    nlls = []
-    loss_fn = CrossEntropyLoss(reduction="none")
-    progress_bar = tqdm(total=args.attention_sink_eval_tokens)
-    input_pos = 0
-    while input_pos < args.attention_sink_eval_tokens:
-        for text in eval_data["text"]:  # pyre-ignore [16]
-            tokens = tokenizer.encode(text, bos=False, eos=False)
-            if len(tokens) <= 0:
-                continue
-            with torch.no_grad():
-                num_tokens = min(
-                    len(tokens) - 1, args.attention_sink_eval_tokens - input_pos
-                )
-                logits = model(
-                    torch.tensor(
-                        [tokens[:num_tokens]], dtype=torch.int64, device=device
-                    ),
-                    torch.tensor([input_pos], dtype=torch.int64, device=device),
-                ).squeeze(dim=0)
-                neg_log_likelihood = loss_fn(
-                    logits,
-                    torch.tensor(
-                        [tokens[1 : num_tokens + 1]], dtype=torch.int64, device=device
-                    ).view(-1),
-                )
-                nlls.append(neg_log_likelihood)
-                input_pos += num_tokens
-                progress_bar.update(num_tokens)
-                if input_pos >= args.attention_sink_eval_tokens:
-                    break
-    ppl = torch.exp(torch.cat(nlls).mean())
-    print(f"Perplexity: {ppl.item()}")
-    return ppl.item()
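For reference, a minimal sketch of how the pieces above fit together after this change, mirroring the updated `main()` in eval_llama.py. The module path, the example flag values (taken from the Test Plan), and the `"llama2"` model name are illustrative placeholders, not part of this diff.

```python
# Sketch of the post-change eval entry point; assumes the ExecuTorch repo layout.
from examples.models.llama.eval_llama_lib import (
    _convert_cli_to_config_format,
    build_args_parser,
    eval_llama,
)

parser = build_args_parser()
args = parser.parse_args(["-c", "stories110M.pt", "-p", "params.json", "-kv"])
args.generate_full_logits = True  # evaluation requires full logits

# Single conversion point: argparse.Namespace -> DictConfig with nested sub-configs.
config = _convert_cli_to_config_format(args)

# Downstream code reads config.eval, config.sequence, config.kv_cache, etc.
eval_llama("llama2", config)
```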
