13
13
# limitations under the License.
14
14
15
15
import logging
16
- from typing import Dict , Union
16
+ from itertools import product
17
+ from typing import Any , Dict , Union
17
18
18
19
from tabulate import tabulate
19
20
from torch .export import ExportedProgram
34
35
from ..recipe_registry import register_recipe
35
36
36
37
37
- @register_recipe ("coreml" )
38
- def export_to_executorch_with_coreml (
38
+ def _export_to_executorch (
39
39
model : Union [CausalLMExportableModule , MaskedLMExportableModule , Seq2SeqLMExportableModule ],
40
40
** kwargs ,
41
41
):
@@ -63,23 +63,14 @@ def export_to_executorch_with_coreml(
63
63
64
64
def _lower_to_executorch (
65
65
exported_programs : Dict [str , ExportedProgram ],
66
- metadata = None ,
67
- ** kwargs ,
66
+ metadata ,
67
+ compute_unit ,
68
+ minimum_deployment_target ,
69
+ compute_precision ,
68
70
) -> Dict [str , ExecutorchProgram ]:
69
- compute_unit = kwargs .get ("compute_unit" , ct .ComputeUnit .ALL )
70
- minimum_deployment_target = kwargs .get ("minimum_deployment_target" , ct .target .iOS15 )
71
- compute_precision = kwargs .get ("compute_precision" , ct .precision .FLOAT16 )
72
- model_type = kwargs .get ("model_type" , "model" )
73
- model_type = {
74
- "model" : CoreMLBackend .MODEL_TYPE .MODEL ,
75
- "modelc" : CoreMLBackend .MODEL_TYPE .COMPILED_MODEL ,
76
- }[model_type ]
77
- take_over_mutable_buffer = kwargs .get ("take_over_mutable_buffer" , True )
78
-
79
71
et_progs = {}
80
72
backend_config_dict = {}
81
73
for pte_name , exported_program in exported_programs .items ():
82
- exported_program = exported_program .run_decompositions ({})
83
74
logging .debug (f"\n Exported program for { pte_name } .pte: { exported_program } " )
84
75
et_progs [pte_name ] = to_edge_transform_and_lower (
85
76
exported_program ,
@@ -89,14 +80,15 @@ def _lower_to_executorch(
89
80
compute_unit = compute_unit ,
90
81
minimum_deployment_target = minimum_deployment_target ,
91
82
compute_precision = compute_precision ,
92
- model_type = model_type ,
83
+ model_type = CoreMLBackend . MODEL_TYPE . MODEL ,
93
84
),
94
- take_over_mutable_buffer = take_over_mutable_buffer ,
85
+ take_over_mutable_buffer = ( minimum_deployment_target >= ct . target . iOS18 ) ,
95
86
)
96
87
],
97
88
compile_config = EdgeCompileConfig (
98
89
_check_ir_validity = False ,
99
- _skip_dim_order = False ,
90
+ # In ET 0.7, we can set _skip_dim_order=False
91
+ _skip_dim_order = True ,
100
92
),
101
93
constant_methods = metadata ,
102
94
).to_executorch (
@@ -114,3 +106,46 @@ def _lower_to_executorch(
114
106
115
107
exported_progs = model .export ()
116
108
return _lower_to_executorch (exported_progs , model .metadata , ** kwargs )
109
+
110
+
111
+ def _get_recipe_kwargs (dtype : str , compute_unit : str ) -> Dict [str , Any ]:
112
+ import coremltools as ct
113
+
114
+ compute_precision = {
115
+ "fp16" : ct .precision .FLOAT16 ,
116
+ "fp32" : ct .precision .FLOAT32 ,
117
+ }[dtype ]
118
+
119
+ compute_unit = {
120
+ "cpu" : ct .ComputeUnit .CPU_ONLY ,
121
+ "gpu" : ct .ComputeUnit .CPU_AND_GPU ,
122
+ "ne" : ct .ComputeUnit .CPU_AND_NE ,
123
+ "all" : ct .ComputeUnit .ALL ,
124
+ }[compute_unit ]
125
+
126
+ recipe_kwargs = {
127
+ "compute_precision" : compute_precision ,
128
+ "compute_unit" : compute_unit ,
129
+ "minimum_deployment_target" : ct .target .iOS18 ,
130
+ }
131
+ return recipe_kwargs
132
+
133
+
134
+ def _make_recipe (recipe_name , recipe_kwargs ):
135
+ @register_recipe (recipe_name )
136
+ def recipe_fn (exported_programs : Dict [str , ExportedProgram ], ** kwargs ):
137
+ return _export_to_executorch (
138
+ exported_programs ,
139
+ ** recipe_kwargs ,
140
+ )
141
+
142
+ return recipe_fn
143
+
144
+
145
+ # Register recipes for CoreML backend
146
+ for dtype , compute_unit in product (["fp32" , "fp16" ], ["cpu" , "gpu" , "ne" , "all" ]):
147
+ recipe_name = f"coreml_{ dtype } "
148
+ if compute_unit != "all" :
149
+ recipe_name += f"_{ compute_unit } "
150
+ recipe_kwargs = _get_recipe_kwargs (dtype = dtype , compute_unit = compute_unit )
151
+ _make_recipe (recipe_name , recipe_kwargs )
0 commit comments