
Commit 6887e93

Fix tests and lint
Parent commit: d9b549f

4 files changed: 38 additions, 24 deletions

examples/models/llama/export_llama_lib.py

Lines changed: 9 additions & 4 deletions
@@ -534,6 +534,11 @@ def canonical_path(path: Union[str, Path], *, dir: bool = False) -> str:
     return return_val
 
 
+def get_default_llm_config() -> DictConfig:
+    default_args = build_args_parser().parse_args([])
+    return _convert_args_to_config(default_args)
+
+
 def _convert_args_to_config(args: argparse.Namespace) -> DictConfig:
     """Convert argparse.Namespace to DictConfig."""
     # Create a dictionary from args
@@ -670,7 +675,9 @@ def export_llama(args: Union[argparse.Namespace, DictConfig]) -> str:
             raise ValueError(
                 f"Converting weights to meta format for {config.model.name} is not yet supported"
             )
-        config.model.checkpoint = download_and_convert_hf_checkpoint(repo_id, convert_weights)
+        config.model.checkpoint = download_and_convert_hf_checkpoint(
+            repo_id, convert_weights
+        )
 
     if config.misc.profile_path is not None:
         try:
@@ -711,9 +718,7 @@ def _prepare_for_llama_export(config: DictConfig) -> LLMEdgeManager:
         if config.model.checkpoint_dir
         else None
     )
-    params_path = (
-        canonical_path(config.model.params) if config.model.params else None
-    )
+    params_path = canonical_path(config.model.params) if config.model.params else None
     output_dir_path = canonical_path(config.export.output_dir, dir=True)
     weight_type = (
         WeightType.FAIRSEQ2 if config.model.type == "FAIRSEQ2" else WeightType.LLAMA
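
The new get_default_llm_config() helper turns the parser defaults into a fully populated DictConfig, so callers can start from a complete config and override individual fields instead of mutating an argparse.Namespace. A minimal sketch of that call pattern, mirroring the updated test later in this commit (the nested kv_cache.* and misc.* keys are taken from that test, not invented here):

    from executorch.examples.models.llama.export_llama_lib import (
        _export_llama,
        get_default_llm_config,
    )

    # Build a DictConfig carrying every default from build_args_parser().
    config = get_default_llm_config()

    # Override only the fields relevant to this run.
    config.kv_cache.use_kv_cache = True
    config.kv_cache.use_sdpa_with_kv_cache = True
    config.misc.verbose = True

    builder = _export_llama(config)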

examples/models/llama/source_transformation/quantize.py

Lines changed: 14 additions & 8 deletions
@@ -783,7 +783,9 @@ def forward(self, indices: torch.Tensor) -> torch.Tensor:
 ############################ Source Transform Start #######################
 
 
-def get_quant_embedding_transform(config: DictConfig, dtype_override: Optional[DType] = None):
+def get_quant_embedding_transform(
+    config: DictConfig, dtype_override: Optional[DType] = None
+):
     if config.quantization.embedding_quantize.startswith("torchao:"):
         from torchao.experimental.quant_api import (
             EmbeddingQuantizer,
@@ -850,15 +852,15 @@ def get_quant_weight_transform(
     # If these optional args are None, don't provide them to quantize().
     quant_args = {}
     if config.quantization.group_size is not None:
-        quant_args['group_size'] = config.quantization.group_size
+        quant_args["group_size"] = config.quantization.group_size
     if config.calibration.tasks is not None:
-        quant_args['calibration_tasks'] = OmegaConf.to_container(config.calibration.tasks)
+        quant_args["calibration_tasks"] = OmegaConf.to_container(
+            config.calibration.tasks
+        )
     if config.calibration.limit is not None:
-        quant_args['calibration_limit'] = config.calibration.limit
+        quant_args["calibration_limit"] = config.calibration.limit
     if config.calibration.seq_length is not None:
-        quant_args['calibration_seq_length'] = config.calibration.seq_length
-
-
+        quant_args["calibration_seq_length"] = config.calibration.seq_length
 
     group_size = config.quantization.group_size
     calibration_tasks = config.calibration.tasks
@@ -871,11 +873,15 @@ def get_quant_weight_transform(
         qmode=config.quantization.mode,
         computation_dtype=computation_dtype,
         checkpoint_dtype=checkpoint_dtype,
-        checkpoint_path=(Path(path) if (path := config.model.checkpoint) is not None else None),
+        checkpoint_path=(
+            Path(path) if (path := config.model.checkpoint) is not None else None
+        ),
         tokenizer_path=(
             Path(path) if (path := config.model.tokenizer_path) is not None else None
         ),
     )
+
+
 def _load_torchao_aten_lib(libname):
     import glob
     import os
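
The OmegaConf.to_container() call in the hunk above matters because values read from a DictConfig come back as ListConfig/DictConfig wrappers rather than plain Python containers; converting before passing them on keeps downstream code that expects a real list working. A small, self-contained sketch of the pattern, assuming an illustrative config shape rather than the project's real schema:

    from omegaconf import OmegaConf

    # Illustrative config; in export_llama_lib the real one comes from
    # _convert_args_to_config().
    config = OmegaConf.create(
        {"calibration": {"tasks": ["wikitext"], "limit": 5, "seq_length": 128}}
    )

    quant_args = {}
    if config.calibration.tasks is not None:
        # ListConfig -> plain list, so callees that expect list/tuple behave as usual.
        quant_args["calibration_tasks"] = OmegaConf.to_container(config.calibration.tasks)
    if config.calibration.limit is not None:
        quant_args["calibration_limit"] = config.calibration.limit

    print(type(quant_args["calibration_tasks"]))  # <class 'list'>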

examples/models/llama/tests/test_export_llama_lib.py

Lines changed: 11 additions & 8 deletions
@@ -4,14 +4,18 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import logging
 import unittest
 
 from executorch.devtools.backend_debug import get_delegation_info
 from executorch.examples.models.llama.export_llama_lib import (
-    export_llama,
-    build_args_parser,
+    _export_llama,
+    get_default_llm_config,
 )
 
+FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
+logging.basicConfig(level=logging.INFO, format=FORMAT, force=True)
+
 UNWANTED_OPS = [
     "aten_permute_copy_default",
     "aten_transpose_copy_default",
@@ -34,13 +38,12 @@ def test_has_expected_ops_and_op_counts(self):
         # we cannot test quantization args in this way
         # since quantization requires promoting meta tensors
         # to device=cpu, which requires real weights.
-        parser = build_args_parser()
-        args = parser.parse_args([])
-        args.use_sdpa_with_kv_cache = True
-        args.use_kv_cache = True
-        args.verbose = True
+        export_config = get_default_llm_config()
+        export_config.kv_cache.use_sdpa_with_kv_cache = True
+        export_config.kv_cache.use_kv_cache = True
+        export_config.misc.verbose = True
 
-        builder = export_llama(args)
+        builder = _export_llama(export_config)
         graph_module = builder.edge_manager.exported_program().graph_module
         delegation_info = get_delegation_info(graph_module)
 

extension/llm/export/builder.py

Lines changed: 4 additions & 4 deletions
@@ -203,7 +203,7 @@ def _export(self, module: Optional[torch.nn.Module] = None) -> ExportedProgram:
         # 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) is for bypassing the dynamo error when tracing
         # 2. torch.no_grad() is for getting rid of the dropout (not sure why training ops will show up)
         with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
-            if hasattr(self.args, "qnn") and self.args.qnn:
+            if self.args.backend.qnn.enabled:
                 # TODO: this is temporary, as qnn flow does not work with new, non-functional export IR.
                 # See issue: https://github.com/pytorch/executorch/issues/7373
 
@@ -249,8 +249,8 @@ def export(self) -> "LLMEdgeManager":
         # Persisting those changes back to an ExportedProgram will require
         # an additional export().
         self.pre_autograd_graph_module = exported_module.module()
-        if hasattr(self.args, "export_only") and self.args.export_only:
-            torch.export.save(exported_module, self.args.output_name)
+        if self.args.export.export_only:
+            torch.export.save(exported_module, self.args.export.output_name)
         return self
 
     def run_canonical_optimizations(self):
@@ -414,7 +414,7 @@ def export_to_edge(self) -> "LLMEdgeManager":
         self.export()
 
         override_export_behaviour = contextlib.nullcontext()
-        if hasattr(self.args, "qnn") and self.args.qnn:
+        if self.args.backend.qnn.enabled:
             override_export_behaviour = patch.object(
                 torch._utils_internal,
                 "export_training_ir_rollout_check",
