21 | 21 | DuplicateDynamicQuantChainPass,
22 | 22 | )
23 | 23 | from executorch.backends.xnnpack._passes.convert_to_linear import ConvertToLinearPass
| 24 | +
24 | 25 | from executorch.exir import EdgeProgramManager, to_edge_transform_and_lower
25 | 26 | from executorch.exir.backend.partitioner import Partitioner
26 | 27 |
33 | 34 | from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
34 | 35 |
35 | 36 | from executorch.extension.export_util.utils import export_to_edge, save_pte_program
36 | | -
37 | 37 | from executorch.extension.llm.export.export_passes import RemoveRedundantTransposes
38 | 38 | from executorch.extension.llm.tokenizer.utils import get_tokenizer
| 39 | +from omegaconf import DictConfig
39 | 40 | from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
40 | 41 | from torch.ao.quantization.quantizer import Quantizer
41 | 42 | from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer
@@ -87,7 +88,7 @@ def __init__(
87 | 88 | use_kv_cache,
88 | 89 | example_inputs,
89 | 90 | example_kwarg_inputs: Optional[Dict] = None,
90 | | - args: Optional[Any] = None,
| 91 | + config: Optional[DictConfig] = None,
91 | 92 | enable_dynamic_shape: bool = False,
92 | 93 | generate_full_logits: bool = False,
93 | 94 | calibration_tasks: Optional[List[str]] = None,
@@ -121,7 +122,7 @@ def __init__(
121 | 122 | self.output_dir = "."
122 | 123 | self.dynamic_shapes = dynamic_shapes
123 | 124 | self._saved_pte_filename = None
124 | | - self.args = args
| 125 | + self.config = config
125 | 126 | self.calibration_tasks = calibration_tasks
126 | 127 | self.calibration_limit = calibration_limit
127 | 128 | self.calibration_seq_length = calibration_seq_length
@@ -203,7 +204,7 @@ def _export(self, module: Optional[torch.nn.Module] = None) -> ExportedProgram:
203 | 204 | # 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) is for bypassing the dynamo error when tracing
204 | 205 | # 2. torch.no_grad() is for getting rid of the dropout (not sure why training ops will show up)
205 | 206 | with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
206 | | - if self.args.backend.qnn.enabled:
| 207 | + if self.config.backend.qnn.enabled:
207 | 208 | # TODO: this is temporary, as qnn flow does not work with new, non-functional export IR.
208 | 209 | # See issue: https://github.com/pytorch/executorch/issues/7373
209 | 210 |
@@ -249,8 +250,8 @@ def export(self) -> "LLMEdgeManager":
249 | 250 | # Persisting those changes back to an ExportedProgram will require
250 | 251 | # an additional export().
251 | 252 | self.pre_autograd_graph_module = exported_module.module()
252 | | - if self.args.export.export_only:
253 | | - torch.export.save(exported_module, self.args.export.output_name)
| 253 | + if self.config.export.export_only:
| 254 | + torch.export.save(exported_module, self.config.export.output_name)
254 | 255 | return self
255 | 256 |
256 | 257 | def run_canonical_optimizations(self):
@@ -414,7 +415,7 @@ def export_to_edge(self) -> "LLMEdgeManager":
414 | 415 | self.export()
415 | 416 |
416 | 417 | override_export_behaviour = contextlib.nullcontext()
417 | | - if self.args.backend.qnn.enabled:
| 418 | + if self.config.backend.qnn.enabled:
418 | 419 | override_export_behaviour = patch.object(
419 | 420 | torch._utils_internal,
420 | 421 | "export_training_ir_rollout_check",
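For context, a minimal sketch (not part of the diff) of how the new `config` parameter might be populated. It only uses the keys this diff actually reads (`backend.qnn.enabled`, `export.export_only`, `export.output_name`); the sample values, the output file name, and the other `LLMEdgeManager` constructor arguments are assumptions.

```python
# Hypothetical usage sketch: build the DictConfig that the updated
# LLMEdgeManager.__init__ accepts via `config=` in place of the old `args`
# namespace. Only fields referenced in this diff are shown.
from omegaconf import OmegaConf

config = OmegaConf.create(
    {
        "backend": {"qnn": {"enabled": False}},
        "export": {"export_only": False, "output_name": "model.pte"},
    }
)

# Attribute-style access matches the call sites changed in this diff.
assert not config.backend.qnn.enabled

# Passed through the new keyword, e.g.:
#   manager = LLMEdgeManager(..., config=config, ...)
```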