diff --git a/.ci/scripts/test_yolo12.sh b/.ci/scripts/test_yolo12.sh
new file mode 100755
index 00000000000..0a7c6273056
--- /dev/null
+++ b/.ci/scripts/test_yolo12.sh
@@ -0,0 +1,198 @@
+#!/bin/bash
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+set -ex
+# shellcheck source=/dev/null
+source "$(dirname "${BASH_SOURCE[0]}")/utils.sh"
+
+while [[ $# -gt 0 ]]; do
+  case "$1" in
+    -model)
+      MODEL_NAME="$2" # e.g. yolo12s
+      shift 2
+      ;;
+    -mode)
+      MODE="$2" # openvino or xnnpack
+      shift 2
+      ;;
+    -pt2e_quantize)
+      PT2E_QUANTIZE="$2"
+      shift 2
+      ;;
+    -upload)
+      UPLOAD_DIR="$2"
+      shift 2
+      ;;
+    -video_path)
+      VIDEO_PATH="$2" # path to the input video for the runner
+      shift 2
+      ;;
+    *)
+      echo "Unknown option: $1"
+      usage
+      ;;
+  esac
+done
+
+# Default mode to openvino if not set
+MODE=${MODE:-"openvino"}
+
+# Default UPLOAD_DIR to empty string if not set
+UPLOAD_DIR="${UPLOAD_DIR:-}"
+
+# Default PT2E_QUANTIZE to empty string if not set
+PT2E_QUANTIZE="${PT2E_QUANTIZE:-}"
+
+# Default CMake Build Type to release mode
+CMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE:-Release}
+
+if [[ $# -lt 5 ]]; then # Assuming 5 mandatory args
+  echo "Expecting at least 5 positional arguments"
+  echo "Usage: [...]"
+fi
+if [[ -z "${MODEL_NAME:-}" ]]; then
+  echo "Missing model name, exiting..."
+  exit 1
+fi
+
+
+if [[ -z "${MODE:-}" ]]; then
+  echo "Missing mode, choose openvino or xnnpack, exiting..."
+  exit 1
+fi
+
+if [[ -z "${PYTHON_EXECUTABLE:-}" ]]; then
+  PYTHON_EXECUTABLE=python3
+fi
+
+TARGET_LIBS=""
+
+if [[ "${MODE}" =~ .*openvino.* ]]; then
+  OPENVINO=ON
+  TARGET_LIBS="$TARGET_LIBS openvino_backend "
+
+  git clone https://github.com/daniil-lyakhov/openvino.git
+
+  cd openvino && git checkout dl/executorch/yolo12
+  git submodule update --init --recursive
+  sudo ./install_build_dependencies.sh
+  mkdir build && cd build
+  cmake .. -DCMAKE_BUILD_TYPE=Release -DENABLE_PYTHON=ON
+  make -j$(nproc)
+
+  cd ..
+  cmake --install build --prefix dist
+
+  source dist/setupvars.sh
+  cd ../backends/openvino
+  pip install -r requirements.txt
+  cd ../../
+else
+  OPENVINO=OFF
+fi
+
+if [[ "${MODE}" =~ .*xnnpack.* ]]; then
+  XNNPACK=ON
+  TARGET_LIBS="$TARGET_LIBS xnnpack_backend "
+else
+  XNNPACK=OFF
+fi
+
+which "${PYTHON_EXECUTABLE}"
+
+
+DIR="examples/models/yolo12"
+$PYTHON_EXECUTABLE -m pip install -r ${DIR}/requirements.txt
+
+cmake_install_executorch_libraries() {
+  rm -rf cmake-out
+  build_dir=cmake-out
+  mkdir $build_dir
+
+
+  retry cmake -DCMAKE_INSTALL_PREFIX="${build_dir}" \
+    -DCMAKE_BUILD_TYPE="${CMAKE_BUILD_TYPE}" \
+    -DEXECUTORCH_BUILD_OPENVINO="$OPENVINO" \
+    -DEXECUTORCH_BUILD_XNNPACK="$XNNPACK" \
+    -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \
+    -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \
+    -DEXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL=ON \
+    -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \
+    -B"${build_dir}"
+
+  # Build the project
+  cmake --build ${build_dir} --target install --config ${CMAKE_BUILD_TYPE} -j$(nproc)
+
+  export CMAKE_ARGS="
+    -DEXECUTORCH_BUILD_OPENVINO="$OPENVINO" \
+    -DEXECUTORCH_BUILD_XNNPACK="$XNNPACK" \
+    -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \
+    -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \
+    -DEXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL=ON \
+    -DEXECUTORCH_ENABLE_LOGGING=ON \
+    -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \
+    -DEXECUTORCH_BUILD_PYBIND=ON"
+
+  echo $TARGET_LIBS
+  export CMAKE_BUILD_ARGS="--target $TARGET_LIBS"
+  pip install . --no-build-isolation
+}
+
+cmake_build_demo() {
+  echo "Building yolo12 runner"
+  retry cmake \
+    -DCMAKE_BUILD_TYPE="$CMAKE_BUILD_TYPE" \
+    -DUSE_OPENVINO_BACKEND="$OPENVINO" \
+    -DUSE_XNNPACK_BACKEND="$XNNPACK" \
+    -Bcmake-out/${DIR} \
+    ${DIR}
+  cmake --build cmake-out/${DIR} -j9 --config "$CMAKE_BUILD_TYPE"
+
+}
+
+cleanup_files() {
+  rm $EXPORTED_MODEL_NAME
+}
+
+prepare_artifacts_upload() {
+  if [ -n "${UPLOAD_DIR}" ]; then
+    echo "Preparing for uploading generated artifacts"
+    zip -j model.zip "${EXPORTED_MODEL_NAME}"
+    mkdir -p "${UPLOAD_DIR}"
+    mv model.zip "${UPLOAD_DIR}"
+    mv result.txt "${UPLOAD_DIR}"
+
+  fi
+}
+
+
+# Export model.
+EXPORTED_MODEL_NAME="${MODEL_NAME}_fp32_${MODE}.pte"
+echo "Exporting ${EXPORTED_MODEL_NAME}"
+EXPORT_ARGS="--model_name=${MODEL_NAME} --backend=${MODE}"
+
+# Build and install the ExecuTorch libraries
+cmake_install_executorch_libraries
+
+$PYTHON_EXECUTABLE -m examples.models.yolo12.export_and_validate ${EXPORT_ARGS}
+
+
+RUNTIME_ARGS="--model_path=${EXPORTED_MODEL_NAME} --input_path=${VIDEO_PATH}"
+# Build the yolo12 runner.
+cmake_build_demo
+# Run yolo12 runner
+NOW=$(date +"%H:%M:%S")
+echo "Starting to run yolo12 runner at ${NOW}"
+# shellcheck source=/dev/null
+cmake-out/examples/models/yolo12/Yolo12DetectionDemo ${RUNTIME_ARGS} > result.txt
+NOW=$(date +"%H:%M:%S")
+echo "Finished at ${NOW}"
+
+RESULT=$(cat result.txt)
+
+prepare_artifacts_upload
+cleanup_files
diff --git a/backends/openvino/README.md b/backends/openvino/README.md
index 95a5f4c364e..a12a64746af 100644
--- a/backends/openvino/README.md
+++ b/backends/openvino/README.md
@@ -43,8 +43,8 @@ executorch
 Before you begin, ensure you have openvino installed and configured on your system:
 
 ```bash
-git clone https://github.com/openvinotoolkit/openvino.git
-cd openvino && git checkout releases/2025/1
+git clone https://github.com/daniil-lyakhov/openvino.git
+cd openvino && git checkout dl/executorch/yolo12
 git submodule update --init --recursive
 sudo ./install_build_dependencies.sh
 mkdir build && cd build
diff --git a/examples/models/yolo12/CMakeLists.txt b/examples/models/yolo12/CMakeLists.txt
new file mode 100644
index 00000000000..5d63bd39ad1
--- /dev/null
+++ b/examples/models/yolo12/CMakeLists.txt
@@ -0,0 +1,84 @@
+cmake_minimum_required(VERSION 3.5)
+
+project(Yolo12DetectionDemo VERSION 0.1)
+
+option(USE_OPENVINO_BACKEND "Build the tutorial with the OPENVINO backend" ON)
+option(USE_XNNPACK_BACKEND "Build the tutorial with the XNNPACK backend" OFF)
+
+set(CMAKE_INCLUDE_CURRENT_DIR ON)
+
+set(CMAKE_CXX_STANDARD 17)
+set(CMAKE_CXX_STANDARD_REQUIRED ON)
+set(CMAKE_CXX_EXTENSIONS OFF)
+
+# OpenCV
+find_package(OpenCV REQUIRED)
+include_directories(${OpenCV_INCLUDE_DIRS})
+# !OpenCV
+
+if(NOT PYTHON_EXECUTABLE)
+  set(PYTHON_EXECUTABLE python3)
+endif()
+
+set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../..)
+set(TORCH_ROOT ${EXECUTORCH_ROOT}/third-party/pytorch)
+
+include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake)
+
+# Let files say "include <executorch/path/to/header.h>".
+set(_common_include_directories ${EXECUTORCH_ROOT}/..)
+
+# Find `executorch` libraries, same as for gflags
+find_package(executorch CONFIG REQUIRED PATHS ${EXECUTORCH_ROOT}/cmake-out)
+target_link_options_shared_lib(executorch)
+
+set(link_libraries gflags)
+list(APPEND link_libraries portable_ops_lib portable_kernels)
+target_link_options_shared_lib(portable_ops_lib)
+
+
+if(USE_XNNPACK_BACKEND)
+  set(xnnpack_backend_libs xnnpack_backend XNNPACK microkernels-prod)
+  list(APPEND link_libraries ${xnnpack_backend_libs})
+  target_link_options_shared_lib(xnnpack_backend)
+endif()
+
+if(USE_OPENVINO_BACKEND)
+  add_subdirectory(${EXECUTORCH_ROOT}/backends/openvino openvino_backend)
+
+  target_include_directories(
+    openvino_backend
+    INTERFACE ${CMAKE_CURRENT_BINARY_DIR}/../../include
+              ${CMAKE_CURRENT_BINARY_DIR}/../../include/executorch/runtime/core/portable_type/c10
+              ${CMAKE_CURRENT_BINARY_DIR}/../../lib
+  )
+  list(APPEND link_libraries openvino_backend)
+  target_link_options_shared_lib(openvino_backend)
+endif()
+
+list(APPEND link_libraries extension_threadpool pthreadpool)
+list(APPEND _common_include_directories
+     ${XNNPACK_ROOT}/third-party/pthreadpool/include
+)
+
+set(PROJECT_SOURCES
+    main.cpp
+    inference.h
+    ${EXECUTORCH_ROOT}/extension/data_loader/file_data_loader.cpp
+    ${EXECUTORCH_ROOT}/extension/evalue_util/print_evalue.cpp
+    ${EXECUTORCH_ROOT}/extension/runner_util/inputs.cpp
+    ${EXECUTORCH_ROOT}/extension/runner_util/inputs_portable.cpp
+)
+
+add_executable(Yolo12DetectionDemo ${PROJECT_SOURCES})
+target_link_libraries(Yolo12DetectionDemo PUBLIC
+    ${link_libraries}
+    ${OpenCV_LIBS}
+    executorch_core
+    extension_module
+    extension_tensor
+)
+
+find_package(Threads REQUIRED)
+target_link_libraries(Yolo12DetectionDemo PRIVATE Threads::Threads)
+target_include_directories(Yolo12DetectionDemo PUBLIC ${_common_include_directories})
\ No newline at end of file
diff --git a/examples/models/yolo12/README.md b/examples/models/yolo12/README.md
new file mode 100644
index 00000000000..c92d8244feb
--- /dev/null
+++ b/examples/models/yolo12/README.md
@@ -0,0 +1,103 @@
+# YOLO12 Detection C++ Inference with ExecuTorch
+
+![YOLO12s detection demo](yolo12s_demo.gif)
+
+This example demonstrates how to perform inference of [Ultralytics YOLO12 family](https://docs.ultralytics.com/models/yolo12/) detection models in C++ leveraging the ExecuTorch backends:
+- [OpenVINO](../../../backends/openvino/README.md)
+- [XNNPACK](../../../backends/xnnpack/README.md)
+
+
+# Instructions
+
+### Step 1: Install ExecuTorch
+
+To install ExecuTorch, follow this [guide](https://pytorch.org/executorch/stable/getting-started-setup.html).
+
+### Step 2: Install the backend of your choice
+
+- [OpenVINO backend installation guide](../../../backends/openvino/README.md#build-instructions)
+- [XNNPACK backend installation guide](https://pytorch.org/executorch/stable/tutorial-xnnpack-delegate-lowering.html#running-the-xnnpack-model-with-cmake)
+
+### Step 3: Install the demo requirements
+
+
+Python demo requirements:
+```bash
+python -m pip install -r examples/models/yolo12/requirements.txt
+```
+
+Demo inference dependency - the OpenCV library:
+https://opencv.org/get-started/
+
+
+### Step 4: Export the YOLO12 model to ExecuTorch
+
+
+OpenVINO:
+```bash
+python export_and_validate.py --model_name yolo12s --input_dims=[1920,1080] --backend openvino --device CPU
+```
+
+XNNPACK:
+```bash
+python export_and_validate.py --model_name yolo12s --input_dims=[1920,1080] --backend xnnpack
+```
+
+> **_NOTE:_** Quantization is coming soon!
+
+The exported model can be validated using the `--validate` key:
+
+```bash
+python export_and_validate.py --model_name yolo12s --backend ... --validate dataset_name.yaml
+```
+
+A list of available datasets and instructions on how to use a custom dataset can be found [here](https://docs.ultralytics.com/datasets/detect/).
+Validation only supports the default `--input_dims`; please do not specify this parameter when using the `--validate` flag.
+
+
+To get a full description of all parameters, use the following command:
+```bash
+python export_and_validate.py --help
+```
+
+### Step 5: Build the demo project
+
+OpenVINO:
+
+```bash
+cd examples/models/yolo12
+mkdir build && cd build
+cmake -DCMAKE_BUILD_TYPE=Release -DUSE_OPENVINO_BACKEND=ON ..
+make -j$(nproc)
+```
+
+XNNPACK:
+
+```bash
+cd examples/models/yolo12
+mkdir build && cd build
+cmake -DCMAKE_BUILD_TYPE=Release -DUSE_XNNPACK_BACKEND=ON ..
+make -j$(nproc)
+```
+
+### Step 6: Run the demo
+
+```bash
+./build/Yolo12DetectionDemo -model_path /path/to/exported/model -input_path /path/to/video/file -output_path /path/to/output/annotated/video
+```
+
+To get a full description of all parameters, use the following command:
+```
+./build/Yolo12DetectionDemo --help
+```
+
+
+# Credits
+
+Ultralytics examples: https://github.com/ultralytics/ultralytics/tree/main/examples
+
+Sample video: https://www.pexels.com/@shanu-1040189/
diff --git a/examples/models/yolo12/export_and_validate.py b/examples/models/yolo12/export_and_validate.py
new file mode 100644
index 00000000000..370763ef800
--- /dev/null
+++ b/examples/models/yolo12/export_and_validate.py
@@ -0,0 +1,405 @@
+# Copyright (c) Intel Corporation
+#
+# Licensed under the BSD License (the "License"); you may not use this file
+# except in compliance with the License. See the license file found in the
+# LICENSE file in the root directory of this source tree.
+
+# mypy: disable-error-code="import-untyped,import-not-found"
+
+
+import argparse
+from itertools import islice
+from typing import Any, Callable, Dict, Iterator, Optional, Tuple
+
+import cv2
+import executorch
+import numpy as np
+import torch
+from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
+from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
+    get_symmetric_quantization_config,
+    XNNPACKQuantizer,
+)
+from executorch.exir import (
+    EdgeCompileConfig,
+    EdgeProgramManager,
+    ExecutorchBackendConfig,
+    ExecutorchProgramManager,
+    to_edge_transform_and_lower,
+)
+from executorch.exir.backend.backend_details import CompileSpec
+from executorch.runtime import Runtime
+from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
+from torch.export.exported_program import ExportedProgram
+from torch.fx.passes.graph_drawer import FxGraphDrawer
+from ultralytics import YOLO
+
+from ultralytics.data.utils import check_det_dataset
+from ultralytics.engine.validator import BaseValidator as Validator
+from ultralytics.utils.torch_utils import de_parallel
+
+
+class CV2VideoIter:
+    def __init__(self, cap) -> None:
+        self._cap = cap
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        success, frame = self._cap.read()
+        if not success:
+            raise StopIteration()
+        return frame
+
+    def __len__(self):
+        return int(self._cap.get(cv2.CAP_PROP_FRAME_COUNT))
+
+
+class CV2VideoDataset(torch.utils.data.IterableDataset):
+    def __init__(self, cap) -> None:
+        super().__init__()
+        self._iter = CV2VideoIter(cap)
+
+    def __iter__(self) -> Iterator:
+        return self._iter
+
+    def __len__(self):
+        return len(self._iter)
+
+
+def visualize_fx_model(model: torch.fx.GraphModule, output_svg_path: str):
+    g = FxGraphDrawer(model, output_svg_path)
+    g.get_dot_graph().write_svg(output_svg_path)
+
+
+def lower_to_openvino(
+    aten_dialect: ExportedProgram,
+    example_args: Tuple[Any, ...],
+    transform_fn: Callable,
+    device: str,
+    calibration_dataset: CV2VideoDataset,
+    subset_size: int,
+    quantize: bool,
+) -> ExecutorchProgramManager:
+    # Import openvino locally to avoid nncf side-effects
+    import nncf.torch
+    from executorch.backends.openvino.partitioner import OpenvinoPartitioner
+    from executorch.backends.openvino.quantizer import OpenVINOQuantizer
+    from executorch.backends.openvino.quantizer.quantizer import QuantizationMode
+    from nncf.experimental.torch.fx import quantize_pt2e
+
+    with nncf.torch.disable_patching():
+        if quantize:
+            target_input_dims = tuple(example_args[0].shape[2:])
+
+            def ext_transform_fn(sample):
+                sample = transform_fn(sample)
+                return pad_to_target(sample, target_input_dims)
+
+            quantizer = OpenVINOQuantizer(mode=QuantizationMode.INT8_TRANSFORMER)
+            quantizer.set_ignored_scope(
+                types=["mul", "sub", "sigmoid", "__getitem__"],
+            )
+            quantized_model = quantize_pt2e(
+                aten_dialect.module(),
+                quantizer,
+                nncf.Dataset(calibration_dataset, ext_transform_fn),
+                subset_size=subset_size,
+                smooth_quant=True,
+                fold_quantize=False,
+            )
+
+            visualize_fx_model(quantized_model, "tmp_quantized_model.svg")
+            aten_dialect = torch.export.export(quantized_model, example_args)
+        # Convert to edge dialect and lower the module to the backend with a custom partitioner
+        compile_spec = [CompileSpec("device", device.encode())]
+        lowered_module: EdgeProgramManager = to_edge_transform_and_lower(
+            aten_dialect,
+            partitioner=[
+                OpenvinoPartitioner(compile_spec),
+            ],
+            compile_config=EdgeCompileConfig(
+                _skip_dim_order=True,
+            ),
+        )
+
+        # Apply backend-specific passes
+        return lowered_module.to_executorch(
+            config=executorch.exir.ExecutorchBackendConfig()
+        )
+
+
+def lower_to_xnnpack(
+    aten_dialect: ExportedProgram,
+    example_args: Tuple[Any, ...],
+    transform_fn: Callable,
+    device: str,
+    calibration_dataset: CV2VideoDataset,
+    subset_size: int,
+    quantize: bool,
+) -> ExecutorchProgramManager:
+    if quantize:
+        quantizer = XNNPACKQuantizer()
+        operator_config = get_symmetric_quantization_config(
+            is_per_channel=False,
+            is_dynamic=False,
+        )
+        quantizer.set_global(operator_config)
+        m = prepare_pt2e(aten_dialect.module(), quantizer)
+        # calibration
+        target_input_dims = tuple(example_args[0].shape[2:])
+        print("Start quantization...")
+        for sample in islice(calibration_dataset, subset_size):
+            sample = transform_fn(sample)
+            sample = pad_to_target(sample, target_input_dims)
+            m(sample)
+        m = convert_pt2e(m)
+        print("Quantized successfully!")
+        aten_dialect = torch.export.export(m, example_args)
+
+    edge = to_edge_transform_and_lower(
+        aten_dialect,
+        partitioner=[XnnpackPartitioner()],
+        compile_config=EdgeCompileConfig(
+            _check_ir_validity=not quantize,
+            _skip_dim_order=True,  # TODO(T182187531): enable dim order in xnnpack
+        ),
+    )
+
+    return edge.to_executorch(
+        config=ExecutorchBackendConfig(extract_delegate_segments=False)
+    )
+
+
+def pad_to_target(
+    image: torch.Tensor,
+    target_size: Tuple[int, int],
+):
+    if image.shape[2:] == target_size:
+        return image
+    img_h, img_w = image.shape[2:]
+    target_h, target_w = target_size
+
+    diff_h = target_h - img_h
+    pad_h_from = diff_h // 2
+    pad_h_to = -(pad_h_from + diff_h % 2) or None
+    diff_w = target_w - img_w
+    pad_w_from = diff_w // 2
+    pad_w_to = -(pad_w_from + diff_w % 2) or None
+
+    result = torch.zeros(
+        (
+            1,
+            3,
+        )
+        + target_size,
+        device=image.device,
+        dtype=image.dtype,
+    )
+    result[:, :, pad_h_from:pad_h_to, pad_w_from:pad_w_to] = image
+    return result
+
+
+def main(
+    model_name: str,
+    input_dims: Tuple[int, int],
+    quantize: bool,
+    video_path: str,
+    subset_size: int,
+    backend: str,
+    device: str,
+    val_dataset_yaml_path: Optional[str],
+):
+    """
+    Main function to load, quantize, and export a YOLO model.
+
+    :param model_name: The name of the YOLO model to load.
+    :param input_dims: Input dims to use for the export of a YOLO12 model.
+    :param quantize: Whether to quantize the model.
+    :param video_path: Path to the video to use for the calibration.
+    :param subset_size: Subset size for the quantized model calibration. The default value is 300.
+    :param backend: The ExecuTorch inference backend (e.g., "openvino", "xnnpack").
+    :param device: The device to run the model on (e.g., "cpu", "gpu").
+    :param val_dataset_yaml_path: Path to the validation dataset file in Ultralytics .yaml format.
+        Performs validation if the path is not None, skips validation otherwise.
+    """
+    # Load the selected model
+    model = YOLO(model_name)
+
+    if quantize:
+        raise NotImplementedError("Quantization is coming soon!")
+        if video_path is None:
+            raise RuntimeError(
+                "Could not quantize model without the video for the calibration."
+                " --video_path parameter is needed."
+            )
+        cap = cv2.VideoCapture(video_path, cv2.CAP_FFMPEG)
+        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+        print(f"Calibration video dims: h: {height} w: {width}")
+        calibration_dataset = CV2VideoDataset(cap)
+    else:
+        calibration_dataset = None
+
+    # Setup pre-processing
+    np_dummy_tensor = np.ones((input_dims[0], input_dims[1], 3))
+    model.predict(np_dummy_tensor, imgsz=((input_dims[0], input_dims[1])), device="cpu")
+
+    pt_model = model.model.to(torch.device("cpu"))
+
+    def transform_fn(frame):
+        input_tensor = model.predictor.preprocess([frame])
+        return input_tensor
+
+    example_args = (transform_fn(np_dummy_tensor),)
+    with torch.no_grad():
+        aten_dialect = torch.export.export(pt_model, args=example_args)
+
+    if backend == "openvino":
+        lower_fn = lower_to_openvino
+    elif backend == "xnnpack":
+        lower_fn = lower_to_xnnpack
+
+    exec_prog = lower_fn(
+        aten_dialect=aten_dialect,
+        example_args=example_args,
+        transform_fn=transform_fn,
+        device=device,
+        calibration_dataset=calibration_dataset,
+        subset_size=subset_size,
+        quantize=quantize,
+    )
+
+    model_file_name = f"{model_name}_{'int8' if quantize else 'fp32'}_{backend}.pte"
+    with open(model_file_name, "wb") as file:
+        exec_prog.write_to_file(file)
+    print(f"Model exported and saved as {model_file_name} on {device}.")
+
+    if val_dataset_yaml_path is not None:
+        if input_dims != [640, 640]:
+            raise NotImplementedError(
+                f"Validation with the custom input shape {input_dims} is not implemented."
+                " Please use the default --input_dims=[640, 640] for the validation."
+            )
+        stats = validate_yolo(model, exec_prog, val_dataset_yaml_path)
+        for stat, value in stats.items():
+            print(f"{stat}: {value}")
+
+
+def _prepare_validation(
+    model: YOLO, dataset_yaml_path: str
+) -> Tuple[Validator, torch.utils.data.DataLoader]:
+    custom = {"rect": False, "batch": 1}  # method defaults
+    args = {
+        **model.overrides,
+        **custom,
+        "mode": "val",
+    }  # highest priority args on the right
+
+    validator = model._smart_load("validator")(args=args, _callbacks=model.callbacks)
+    stride = 32  # default stride
+    validator.stride = stride  # used in get_dataloader() for padding
+    validator.data = check_det_dataset(dataset_yaml_path)
+    validator.init_metrics(de_parallel(model))
+
+    data_loader = validator.get_dataloader(
+        validator.data.get(validator.args.split), validator.args.batch
+    )
+
+    return validator, data_loader
+
+
+def validate_yolo(
+    model: YOLO, exec_prog: ExecutorchProgramManager, dataset_yaml_path: str
+) -> Dict[str, float]:
+    """
+    Runs validation on a YOLO model using an ExecuTorch program and a dataset in Ultralytics format.
+
+    :param model: The YOLO model instance to validate.
+    :param exec_prog: The ExecuTorch program manager containing the compiled model.
+    :param dataset_yaml_path: Path to the validation dataset file in Ultralytics .yaml format.
+    :return: Dictionary of validation statistics computed over the dataset.
+    """
+    # Load model from buffer
+    runtime = Runtime.get()
+    program = runtime.load_program(exec_prog.buffer)
+    method = program.load_method("forward")
+    if method is None:
+        raise ValueError("Load method failed")
+    validator, data_loader = _prepare_validation(model, dataset_yaml_path)
+    print(f"Start validation on {dataset_yaml_path} dataset ...")
+    for batch in data_loader:
+        batch = validator.preprocess(batch)
+        preds = method.execute((batch["img"],))
+        preds = validator.postprocess(preds)
+        validator.update_metrics(preds, batch)
+    stats = validator.get_stats()
+    return stats
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description="Export FP32 and INT8 Ultralytics YOLO models with ExecuTorch."
+    )
+    parser.add_argument(
+        "--model_name",
+        type=str,
+        default="yolo12s",
+        choices=["yolo12n", "yolo12s", "yolo12m", "yolo12l", "yolo12x"],
+        help="Ultralytics yolo12 model name.",
+    )
+    parser.add_argument(
+        "--input_dims",
+        type=eval,
+        default=[640, 640],
+        help="Input model dimensions in format [height, width] or (height, width). Default model dimensions are [640, 640].",
+    )
+    parser.add_argument(
+        "--video_path",
+        type=str,
+        help="Path to the input video file to use for the quantization calibration.",
+    )
+    parser.add_argument(
+        "--quantize", action="store_true", help="Enable model quantization."
+    )
+    parser.add_argument(
+        "--subset_size",
+        type=int,
+        default=300,
+        help="Subset size for the quantized model calibration. The default value is 300.",
+    )
+    parser.add_argument(
+        "--backend",
+        type=str,
+        default="openvino",
+        choices=["openvino", "xnnpack"],
+        help="Select the ExecuTorch inference backend (openvino, xnnpack). openvino by default.",
+    )
+    parser.add_argument(
+        "--device",
+        type=str,
+        default="CPU",
+        help="Target device for compiling the model (e.g., CPU, GPU). Default is CPU.",
+    )
+    parser.add_argument(
+        "--validate",
+        nargs="?",
+        const="coco128.yaml",
+        help="Validate the ExecuTorch model using the Ultralytics validation pipeline."
+        " Default validation dataset is coco128.yaml.",
+    )
+
+    args = parser.parse_args()
+
+    # Run the main function with parsed arguments
+    main(
+        model_name=args.model_name,
+        input_dims=args.input_dims,
+        quantize=args.quantize,
+        val_dataset_yaml_path=args.validate,
+        video_path=args.video_path,
+        subset_size=args.subset_size,
+        backend=args.backend,
+        device=args.device,
+    )
diff --git a/examples/models/yolo12/inference.h b/examples/models/yolo12/inference.h
new file mode 100644
index 00000000000..467ef5ce0ca
--- /dev/null
+++ b/examples/models/yolo12/inference.h
@@ -0,0 +1,151 @@
+#ifndef INFERENCE_H
+#define INFERENCE_H
+
+#include <string>
+#include <vector>
+
+#include <executorch/extension/module/module.h>
+#include <executorch/extension/tensor/tensor.h>
+#include <executorch/runtime/core/error.h>
+#include <executorch/runtime/core/result.h>
+#include <opencv2/dnn.hpp>
+#include <opencv2/imgproc.hpp>
+#include <opencv2/opencv.hpp>
+
+using executorch::aten::ScalarType;
+using executorch::extension::from_blob;
+using executorch::extension::Module;
+using executorch::runtime::Error;
+using executorch::runtime::Result;
+
+struct Detection {
+  int class_id{0};
+  std::string className{};
+  float confidence{0.0};
+  cv::Rect box{};
+};
+
+struct DetectionConfig {
+  std::vector<std::string> classes;
+  float modelScoreThreshold;
+  float modelNMSThreshold;
+};
+
+cv::Mat scale_with_padding(
+    cv::Mat& source,
+    int* pad_x,
+    int* pad_y,
+    float* scale,
+    cv::Size img_dims) {
+  int col = source.cols;
+  int row = source.rows;
+  int m_inputWidth = img_dims.width;
+  int m_inputHeight = img_dims.height;
+  if (col == m_inputWidth and row == m_inputHeight) {
+    return source;
+  }
+
+  *scale = std::min(m_inputWidth / (float)col, m_inputHeight / (float)row);
+  int resized_w = col * *scale;
+  int resized_h = row * *scale;
+  *pad_x = (m_inputWidth - resized_w) / 2;
+  *pad_y = (m_inputHeight - resized_h) / 2;
+
+  cv::Mat resized;
+  cv::resize(source, resized, cv::Size(resized_w, resized_h));
+  cv::Mat result = cv::Mat::zeros(m_inputHeight, m_inputWidth, source.type());
+  resized.copyTo(result(cv::Rect(*pad_x, *pad_y, resized_w, resized_h)));
+  resized.release();
+  return result;
+}
+
+std::vector<Detection> infer_yolo_once(
+    Module& module,
+    cv::Mat input,
+    cv::Size img_dims,
+    const DetectionConfig yolo_config) {
+  int pad_x, pad_y;
+  float scale;
+  input = scale_with_padding(input, &pad_x, &pad_y, &scale, img_dims);
+
+  cv::Mat blob;
+  cv::dnn::blobFromImage(
+      input, blob, 1.0 / 255.0, img_dims, cv::Scalar(), true, false);
+  const auto t_input = from_blob(
+      (void*)blob.data,
+      std::vector<executorch::aten::SizesType>(blob.size.p, blob.size.p + blob.dims),
+      ScalarType::Float);
+  const auto result = module.forward(t_input);
+
+  ET_CHECK_MSG(
+      result.ok(),
+      "Execution of method forward failed with status 0x%" PRIx32,
+      (uint32_t)result.error());
+
+  const auto t = result->at(0).toTensor(); // Using only the 0 output
+  // yolov8 has an output of shape (batchSize, 84, 8400) (Num classes +
+  // box[x,y,w,h])
+  cv::Mat mat_output(t.dim() - 1, t.sizes().data() + 1, CV_32FC1, t.data_ptr());
+
+  std::vector<int> class_ids;
+  std::vector<float> confidences;
+  std::vector<cv::Rect> boxes;
+
+  // Iterate over detections and collect class IDs, confidence scores, and
+  // bounding boxes
+  for (int i = 0; i < mat_output.cols; ++i) {
+    const cv::Mat classes_scores =
+        mat_output.col(i).rowRange(4, mat_output.rows);
+
+    cv::Point class_id;
+    double score;
+    cv::minMaxLoc(
+        classes_scores,
+        nullptr,
+        &score,
+        nullptr,
+        &class_id); // Find the class with the highest score
+
+    // Check if the detection meets the confidence threshold
+    if (score <= yolo_config.modelScoreThreshold)
+      continue;
+
+    class_ids.push_back(class_id.y);
+    confidences.push_back(score);
+
+    const float x = mat_output.at<float>(0, i);
+    const float y = mat_output.at<float>(1, i);
+    const float w = mat_output.at<float>(2, i);
+    const float h = mat_output.at<float>(3, i);
+
+    const int left = int((x - 0.5 * w - pad_x) / scale);
+    const int top = int((y - 0.5 * h - pad_y) / scale);
+    const int width = int(w / scale);
+    const int height = int(h / scale);
+
+    boxes.push_back(cv::Rect(left, top, width, height));
+  }
+
+  std::vector<int> nms_result;
+  cv::dnn::NMSBoxes(
+      boxes,
+      confidences,
+      yolo_config.modelScoreThreshold,
+      yolo_config.modelNMSThreshold,
+      nms_result);
+
+  std::vector<Detection> detections{};
+  for (auto& idx : nms_result) {
+    Detection result;
+    result.class_id = class_ids[idx];
+    result.confidence = confidences[idx];
+
+    result.className = yolo_config.classes[result.class_id];
+    result.box = boxes[idx];
+
+    detections.push_back(result);
+  }
+
+  return detections;
+}
+#endif // INFERENCE_H
diff --git a/examples/models/yolo12/main.cpp b/examples/models/yolo12/main.cpp
new file mode 100644
index 00000000000..95ea98d6634
--- /dev/null
+++ b/examples/models/yolo12/main.cpp
@@ -0,0 +1,168 @@
+#include "inference.h"
+
+#include <gflags/gflags.h>
+
+void draw_detection(
+    cv::Mat& frame,
+    const Detection detection,
+    const cv::Scalar color);
+
+DetectionConfig DEFAULT_YOLO_CONFIG = {
+    {"person", "bicycle", "car",
+     "motorcycle", "airplane", "bus",
+     "train", "truck", "boat",
+     "traffic light", "fire hydrant", "stop sign",
+     "parking meter", "bench", "bird",
+     "cat", "dog", "horse",
+     "sheep", "cow", "elephant",
+     "bear", "zebra", "giraffe",
+     "backpack", "umbrella", "handbag",
+     "tie", "suitcase", "frisbee",
+     "skis", "snowboard", "sports ball",
+     "kite", "baseball bat", "baseball glove",
+     "skateboard", "surfboard", "tennis racket",
+     "bottle", "wine glass", "cup",
+     "fork", "knife", "spoon",
+     "bowl", "banana", "apple",
+     "sandwich", "orange", "broccoli",
+     "carrot", "hot dog", "pizza",
+     "donut", "cake", "chair",
+     "couch", "potted plant", "bed",
+     "dining table", "toilet", "tv",
+     "laptop", "mouse", "remote",
+     "keyboard", "cell phone", "microwave",
+     "oven", "toaster", "sink",
+     "refrigerator", "book", "clock",
+     "vase", "scissors", "teddy bear",
+     "hair drier", "toothbrush"},
+    0.45,
+    0.50};
+
+DEFINE_string(
+    model_path,
+    "model.pte",
+    "Model serialized in flatbuffer format.");
+
+DEFINE_string(input_path, "input.mp4", "Path to the mp4 input video");
+
+DEFINE_string(output_path, "output.mp4", "Path to the mp4 output video");
+
+int main(int argc, char** argv) {
+  executorch::runtime::runtime_init();
+  gflags::ParseCommandLineFlags(&argc, &argv, true);
+
+  // Use Mmap model to enable loading of big YOLO models in OpenVINO
+  Module yolo_module(FLAGS_model_path, Module::LoadMode::Mmap);
+
+  auto error = yolo_module.load();
+
+  ET_CHECK_MSG(
+      error == Error::Ok,
+      "Loading of the model failed with status 0x%" PRIx32,
+      (uint32_t)error);
+  error = yolo_module.load_forward();
+  ET_CHECK_MSG(
+      error == Error::Ok,
+      "Loading of the forward method failed with status 0x%" PRIx32,
+      (uint32_t)error);
+
+  const auto model_input_shape =
+      yolo_module.method_meta("forward")->input_tensor_meta(0)->sizes();
+  std::cout << "Model input shape: [";
+  for (auto& dim : model_input_shape) {
+    std::cout << dim << ", ";
+  }
+  std::cout << "]" << std::endl;
+  const cv::Size img_dims = {model_input_shape[3], model_input_shape[2]};
+
+  cv::VideoCapture cap(FLAGS_input_path.c_str());
+  if (!cap.isOpened()) {
+    std::cout << "Error opening video stream or file" << std::endl;
+    return -1;
+  }
+  const auto frame_width = cap.get(cv::CAP_PROP_FRAME_WIDTH);
+  const auto frame_height = cap.get(cv::CAP_PROP_FRAME_HEIGHT);
+  const auto video_length = cap.get(cv::CAP_PROP_FRAME_COUNT);
+  std::cout << "Input video shape: [3, " << frame_width << ", " << frame_height
+            << "]" << std::endl;
+
+  cv::VideoWriter video(
+      FLAGS_output_path.c_str(),
+      cv::VideoWriter::fourcc('m', 'p', '4', 'v'),
+      30,
+      cv::Size(frame_width, frame_height));
+
+  std::cout << "Start the detection..." << std::endl;
+  et_timestamp_t time_spent_executing = 0;
+  unsigned long long iters = 0;
+  // Show progress every 10%
+  unsigned long long progress_bar_tick = std::round(video_length / 10);
+  while (true) {
+    cv::Mat frame;
+    cap >> frame;
+
+    if (frame.empty())
+      break;
+
+    const et_timestamp_t before_execute = et_pal_current_ticks();
+    std::vector<Detection> output =
+        infer_yolo_once(yolo_module, frame, img_dims, DEFAULT_YOLO_CONFIG);
+
+    for (auto& detection : output) {
+      draw_detection(frame, detection, cv::Scalar(0, 0, 255));
+    }
+    const et_timestamp_t after_execute = et_pal_current_ticks();
+    time_spent_executing += after_execute - before_execute;
+    iters++;
+
+    if (!(iters % progress_bar_tick)) {
+      const int percent_ready = (100 * iters) / video_length;
+      std::cout << iters << " out of " << video_length
+                << " frames are processed (" << percent_ready << "%)"
+                << std::endl;
+    }
+    video.write(frame);
+  }
+
+  const auto tick_ratio = et_pal_ticks_to_ns_multiplier();
+  constexpr auto NANOSECONDS_PER_MILLISECOND = 1000000;
+
+  double elapsed_ms = static_cast<double>(time_spent_executing) *
+      tick_ratio.numerator / tick_ratio.denominator /
+      NANOSECONDS_PER_MILLISECOND;
+  std::cout << "Model executed successfully " << iters << " times in "
+            << elapsed_ms << " ms." << std::endl;
+  std::cout << "Average detection time: " << elapsed_ms / iters << " ms."
+            << std::endl;
+  cap.release();
+  video.release();
+}
+
+void draw_detection(
+    cv::Mat& frame,
+    const Detection detection,
+    const cv::Scalar color) {
+  cv::Rect box = detection.box;
+
+  // Detection box
+  cv::rectangle(frame, box, color, 2);
+
+  // Detection box text
+  std::string classString = detection.className + ' ' +
+      std::to_string(detection.confidence).substr(0, 4);
+  cv::Size textSize =
+      cv::getTextSize(classString, cv::FONT_HERSHEY_DUPLEX, 1, 2, 0);
+  cv::Rect textBox(
+      box.x, box.y - 40, textSize.width + 10, textSize.height + 20);
+
+  cv::rectangle(frame, textBox, color, cv::FILLED);
+  cv::putText(
+      frame,
+      classString,
+      cv::Point(box.x + 5, box.y - 10),
+      cv::FONT_HERSHEY_DUPLEX,
+      1,
+      cv::Scalar(0, 0, 0),
+      2,
+      0);
+}
\ No newline at end of file
diff --git a/examples/models/yolo12/requirements.txt b/examples/models/yolo12/requirements.txt
new file mode 100644
index 00000000000..de537f46170
--- /dev/null
+++ b/examples/models/yolo12/requirements.txt
@@ -0,0 +1 @@
+ultralytics==8.3.97
\ No newline at end of file
diff --git a/examples/models/yolo12/yolo12s_demo.gif b/examples/models/yolo12/yolo12s_demo.gif
new file mode 100644
index 00000000000..be029bf416c
Binary files /dev/null and b/examples/models/yolo12/yolo12s_demo.gif differ
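
Reviewer note: outside of the C++ runner, the `.pte` produced by Step 4 can also be smoke-tested from Python with the same `executorch.runtime` API that `validate_yolo` uses in `export_and_validate.py`. This is a minimal sketch, not part of the PR; the file name `yolo12s_fp32_openvino.pte` and the 640x640 input shape are assumptions matching the default export settings.

```python
# Minimal smoke test for an exported YOLO12 .pte file.
# Assumes yolo12s_fp32_openvino.pte was exported with the default
# --input_dims=[640, 640]; adjust the path and shape to your export.
from pathlib import Path

import torch
from executorch.runtime import Runtime

runtime = Runtime.get()
program = runtime.load_program(Path("yolo12s_fp32_openvino.pte").read_bytes())
method = program.load_method("forward")

# Dummy NCHW float input in [0, 1]; real inputs go through the Ultralytics preprocessor.
dummy = torch.rand(1, 3, 640, 640)
outputs = method.execute((dummy,))
# Output 0 holds the detections, e.g. (1, 84, 8400): 4 box coordinates + 80 class scores.
print(outputs[0].shape)
```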