9
9
10
10
import argparse
11
11
from itertools import islice
12
- from typing import Any , Iterator , Tuple
12
+ from typing import Any , Dict , Iterator , Optional , Tuple
13
13
14
14
import cv2
15
15
import executorch
28
28
to_edge_transform_and_lower ,
29
29
)
30
30
from executorch .exir .backend .backend_details import CompileSpec
31
+ from executorch .runtime import Runtime
31
32
from torch .ao .quantization .quantize_pt2e import convert_pt2e , prepare_pt2e
32
33
from torch .export .exported_program import ExportedProgram
33
34
from torch .fx .passes .graph_drawer import FxGraphDrawer
34
35
from ultralytics import YOLO
35
36
37
+ from ultralytics .data .utils import check_det_dataset
38
+ from ultralytics .engine .validator import BaseValidator as Validator
39
+ from ultralytics .utils .torch_utils import de_parallel
40
+
36
41
37
42
class CV2VideoIter :
38
43
def __init__ (self , cap ) -> None :
@@ -204,17 +209,21 @@ def main(
204
209
subset_size : int ,
205
210
backend : str ,
206
211
device : str ,
212
+ val_dataset_yaml_path : Optional [str ],
207
213
):
208
214
"""
209
215
Main function to load, quantize, and export an Yolo model model.
210
216
211
217
:param model_name: The name of the YOLO model to load.
218
+ :param input_dims: Input dims to use for the export of a YOLO12 model.
212
219
:param quantize: Whether to quantize the model.
213
220
:param video_path: Path to the video to use for the calibration
221
+ :param subset_size: Subset size for the quantized model calibration. The default value is 300.
214
222
:param backend: The Executorch inference backend (e.g., "openvino", "xnnpack").
215
223
:param device: The device to run the model on (e.g., "cpu", "gpu").
224
+ :param val_dataset_yaml_path: Path to the validation dataset file in Ultralytics .yaml format.
225
+ Performs validation if the path is not None, skips validation otherwise.
216
226
"""
217
-
218
227
# Load the selected model
219
228
model = YOLO (model_name )
220
229
@@ -267,6 +276,67 @@ def transform_fn(frame):
267
276
exec_prog .write_to_file (file )
268
277
print (f"Model exported and saved as { model_file_name } on { device } ." )
269
278
279
+ if val_dataset_yaml_path is not None :
280
+ if input_dims != [640 , 640 ]:
281
+ raise NotImplementedError (
282
+ f"Validation with the custom input shape { input_dims } is not implmenented."
283
+ " Please use the default --input_dims=[640, 640] for the validation."
284
+ )
285
+ stats = validate_yolo (model , exec_prog , val_dataset_yaml_path )
286
+ for stat , value in stats .items ():
287
+ print (f"{ stat } : { value } " )
288
+
289
+
290
+ def _prepare_validation (
291
+ model : YOLO , dataset_yaml_path : str
292
+ ) -> Tuple [Validator , torch .utils .data .DataLoader ]:
293
+ custom = {"rect" : False , "batch" : 1 } # method defaults
294
+ args = {
295
+ ** model .overrides ,
296
+ ** custom ,
297
+ "mode" : "val" ,
298
+ } # highest priority args on the right
299
+
300
+ validator = model ._smart_load ("validator" )(args = args , _callbacks = model .callbacks )
301
+ stride = 32 # default stride
302
+ validator .stride = stride # used in get_dataloader() for padding
303
+ validator .data = check_det_dataset (dataset_yaml_path )
304
+ validator .init_metrics (de_parallel (model ))
305
+
306
+ data_loader = validator .get_dataloader (
307
+ validator .data .get (validator .args .split ), validator .args .batch
308
+ )
309
+
310
+ return validator , data_loader
311
+
312
+
313
+ def validate_yolo (
314
+ model : YOLO , exec_prog : ExecutorchProgramManager , dataset_yaml_path : str
315
+ ) -> Dict [str , float ]:
316
+ """
317
+ Runs validation on a YOLO model using an ExecuTorch program and a dataset in Ultralytics format.
318
+
319
+ :param model: The YOLO model instance to validate.
320
+ :param exec_prog: The ExecuTorch program manager containing the compiled model.
321
+ :param dataset_yaml_path: Path to the validation dataset file in Ultralytics .yaml format.
322
+ :return: Dictionary of validation statistics computed over the dataset.
323
+ """
324
+ # Load model from buffer
325
+ runtime = Runtime .get ()
326
+ program = runtime .load_program (exec_prog .buffer )
327
+ method = program .load_method ("forward" )
328
+ if method is None :
329
+ raise ValueError ("Load method failed" )
330
+ validator , data_loader = _prepare_validation (model , dataset_yaml_path )
331
+ print (f"Start validation on { dataset_yaml_path } dataset ..." )
332
+ for batch in data_loader :
333
+ batch = validator .preprocess (batch )
334
+ preds = method .execute ((batch ["img" ],))
335
+ preds = validator .postprocess (preds )
336
+ validator .update_metrics (preds , batch )
337
+ stats = validator .get_stats ()
338
+ return stats
339
+
270
340
271
341
if __name__ == "__main__" :
272
342
parser = argparse .ArgumentParser (
@@ -312,6 +382,13 @@ def transform_fn(frame):
312
382
default = "CPU" ,
313
383
help = "Target device for compiling the model (e.g., CPU, GPU). Default is CPU." ,
314
384
)
385
+ parser .add_argument (
386
+ "--validate" ,
387
+ nargs = "?" ,
388
+ const = "coco128.yaml" ,
389
+ help = "Validate executorch model using the Ultralytics validation pipeline."
390
+ " Default validateion dataset is coco128.yaml." ,
391
+ )
315
392
316
393
args = parser .parse_args ()
317
394
@@ -320,6 +397,7 @@ def transform_fn(frame):
320
397
model_name = args .model_name ,
321
398
input_dims = args .input_dims ,
322
399
quantize = args .quantize ,
400
+ val_dataset_yaml_path = args .validate ,
323
401
video_path = args .video_path ,
324
402
subset_size = args .subset_size ,
325
403
backend = args .backend ,
0 commit comments