# coding=utf-8
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pipelines running different backends."""

from typing import TYPE_CHECKING, Any, Optional, Union

import torch

from optimum.utils.import_utils import (
    is_ipex_available,
    is_onnxruntime_available,
    is_openvino_available,
    is_optimum_intel_available,
    is_optimum_onnx_available,
)


if TYPE_CHECKING:
    from transformers import (
        BaseImageProcessor,
        FeatureExtractionMixin,
        Pipeline,
        PretrainedConfig,
        PreTrainedModel,
        PreTrainedTokenizer,
        PreTrainedTokenizerFast,
        ProcessorMixin,
        TFPreTrainedModel,
    )


# The docstring is a copy of transformers.pipelines.pipeline's docstring, with minor
# modifications to reflect the fact that this pipeline loads accelerated models using optimum.
def pipeline(
    task: Optional[str] = None,
    model: Optional[Union[str, "PreTrainedModel", "TFPreTrainedModel"]] = None,
    config: Optional[Union[str, "PretrainedConfig"]] = None,
    tokenizer: Optional[Union[str, "PreTrainedTokenizer", "PreTrainedTokenizerFast"]] = None,
    feature_extractor: Optional[Union[str, "FeatureExtractionMixin"]] = None,
    image_processor: Optional[Union[str, "BaseImageProcessor"]] = None,
    processor: Optional[Union[str, "ProcessorMixin"]] = None,
    framework: Optional[str] = None,
    revision: Optional[str] = None,
    use_fast: bool = True,
    token: Optional[Union[str, bool]] = None,
    device: Optional[Union[int, str, "torch.device"]] = None,
    device_map: Optional[Union[str, dict[str, Union[int, str]]]] = None,
    torch_dtype: Optional[Union[str, "torch.dtype"]] = "auto",
    trust_remote_code: Optional[bool] = None,
    model_kwargs: Optional[dict[str, Any]] = None,
    pipeline_class: Optional[Any] = None,
    accelerator: Optional[str] = None,
    **kwargs: Any,
) -> "Pipeline":
    """Utility factory method to build a [`Pipeline`] with an optimum accelerated model, similar to `transformers.pipeline`.

    A pipeline consists of:
        - One or more components for pre-processing model inputs, such as a [tokenizer](tokenizer),
          [image_processor](image_processor), [feature_extractor](feature_extractor), or [processor](processors).
        - A [model](model) that generates predictions from the inputs.
        - Optional post-processing steps to refine the model's output, which can also be handled by processors.

    <Tip>
    Although `tokenizer`, `feature_extractor`, `image_processor`, and `processor` are all optional arguments, they
    shouldn't all be specified at once. If these components are not provided, `pipeline` will try to load the
    required ones automatically. If you want to provide them explicitly, refer to the documentation of the specific
    pipeline to see which components it requires.
    </Tip>
    Args:
        task (`str`):
            The task defining which pipeline will be returned. Currently accepted tasks are:
            - `"audio-classification"`: will return an [`AudioClassificationPipeline`].
            - `"automatic-speech-recognition"`: will return an [`AutomaticSpeechRecognitionPipeline`].
            - `"depth-estimation"`: will return a [`DepthEstimationPipeline`].
            - `"document-question-answering"`: will return a [`DocumentQuestionAnsweringPipeline`].
            - `"feature-extraction"`: will return a [`FeatureExtractionPipeline`].
            - `"fill-mask"`: will return a [`FillMaskPipeline`].
            - `"image-classification"`: will return an [`ImageClassificationPipeline`].
            - `"image-feature-extraction"`: will return an [`ImageFeatureExtractionPipeline`].
            - `"image-segmentation"`: will return an [`ImageSegmentationPipeline`].
            - `"image-text-to-text"`: will return an [`ImageTextToTextPipeline`].
            - `"image-to-image"`: will return an [`ImageToImagePipeline`].
            - `"image-to-text"`: will return an [`ImageToTextPipeline`].
            - `"mask-generation"`: will return a [`MaskGenerationPipeline`].
            - `"object-detection"`: will return an [`ObjectDetectionPipeline`].
            - `"question-answering"`: will return a [`QuestionAnsweringPipeline`].
            - `"summarization"`: will return a [`SummarizationPipeline`].
            - `"table-question-answering"`: will return a [`TableQuestionAnsweringPipeline`].
            - `"text2text-generation"`: will return a [`Text2TextGenerationPipeline`].
            - `"text-classification"` (alias `"sentiment-analysis"` available): will return a
              [`TextClassificationPipeline`].
            - `"text-generation"`: will return a [`TextGenerationPipeline`].
            - `"text-to-audio"` (alias `"text-to-speech"` available): will return a [`TextToAudioPipeline`].
            - `"token-classification"` (alias `"ner"` available): will return a [`TokenClassificationPipeline`].
            - `"translation"`: will return a [`TranslationPipeline`].
            - `"translation_xx_to_yy"`: will return a [`TranslationPipeline`].
            - `"video-classification"`: will return a [`VideoClassificationPipeline`].
            - `"visual-question-answering"`: will return a [`VisualQuestionAnsweringPipeline`].
            - `"zero-shot-classification"`: will return a [`ZeroShotClassificationPipeline`].
            - `"zero-shot-image-classification"`: will return a [`ZeroShotImageClassificationPipeline`].
            - `"zero-shot-audio-classification"`: will return a [`ZeroShotAudioClassificationPipeline`].
            - `"zero-shot-object-detection"`: will return a [`ZeroShotObjectDetectionPipeline`].
        model (`str` or [`ORTModel`] or [`OVModel`], *optional*):
            The model that will be used by the pipeline to make predictions. This can be a model identifier or an
            actual instance of an ONNX Runtime or OpenVINO model inheriting from [`ORTModel`] or [`OVModel`].
            If not provided, the default for the `task` will be loaded.
        config (`str` or [`PretrainedConfig`], *optional*):
            The configuration that will be used by the pipeline to instantiate the model. This can be a model
            identifier or an actual pretrained model configuration inheriting from [`PretrainedConfig`].
            If not provided, the default configuration file for the requested model will be used. That means that if
            `model` is given, its default configuration will be used. However, if `model` is not supplied, this
            `task`'s default model's config is used instead.
        tokenizer (`str` or [`PreTrainedTokenizer`], *optional*):
            The tokenizer that will be used by the pipeline to encode data for the model. This can be a model
            identifier or an actual pretrained tokenizer inheriting from [`PreTrainedTokenizer`].
            If not provided, the default tokenizer for the given `model` will be loaded (if it is a string). If `model`
            is not specified or not a string, then the default tokenizer for `config` is loaded (if it is a string).
            However, if `config` is also not given or not a string, then the default tokenizer for the given `task`
            will be loaded.
        feature_extractor (`str` or [`PreTrainedFeatureExtractor`], *optional*):
            The feature extractor that will be used by the pipeline to encode data for the model. This can be a model
            identifier or an actual pretrained feature extractor inheriting from [`PreTrainedFeatureExtractor`].
            Feature extractors are used for non-NLP models, such as speech or vision models, as well as multi-modal
            models. Multi-modal models will also require a tokenizer to be passed.
            If not provided, the default feature extractor for the given `model` will be loaded (if it is a string). If
            `model` is not specified or not a string, then the default feature extractor for `config` is loaded (if it
            is a string). However, if `config` is also not given or not a string, then the default feature extractor
            for the given `task` will be loaded.
        image_processor (`str` or [`BaseImageProcessor`], *optional*):
            The image processor that will be used by the pipeline to preprocess images for the model. This can be a
            model identifier or an actual image processor inheriting from [`BaseImageProcessor`].
            Image processors are used for vision models and multi-modal models that require image inputs. Multi-modal
            models will also require a tokenizer to be passed.
            If not provided, the default image processor for the given `model` will be loaded (if it is a string). If
            `model` is not specified or not a string, then the default image processor for `config` is loaded (if it is
            a string).
        processor (`str` or [`ProcessorMixin`], *optional*):
            The processor that will be used by the pipeline to preprocess data for the model. This can be a model
            identifier or an actual processor inheriting from [`ProcessorMixin`].
            Processors are used for multi-modal models that require multi-modal inputs, for example, a model that
            requires both text and image inputs.
            If not provided, the default processor for the given `model` will be loaded (if it is a string). If `model`
            is not specified or not a string, then the default processor for `config` is loaded (if it is a string).
        framework (`str`, *optional*):
            The framework to use, either `"pt"` for PyTorch or `"tf"` for TensorFlow. The specified framework must be
            installed.
            If no framework is specified, will default to the one currently installed. If no framework is specified and
            both frameworks are installed, will default to the framework of the `model`, or to PyTorch if no model is
            provided.
        revision (`str`, *optional*, defaults to `"main"`):
            When passing a task name or a string model identifier: the specific model version to use. It can be a
            branch name, a tag name, or a commit id. Since we use a git-based system for storing models and other
            artifacts on huggingface.co, `revision` can be any identifier allowed by git.
        use_fast (`bool`, *optional*, defaults to `True`):
            Whether or not to use a Fast tokenizer if possible (a [`PreTrainedTokenizerFast`]).
        token (`str` or `bool`, *optional*):
            The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
            when running `hf auth login` (stored in `~/.huggingface`).
        device (`int` or `str` or `torch.device`):
            Defines the device (*e.g.*, `"cpu"`, `"cuda:1"`, `"mps"`, or a GPU ordinal rank like `1`) on which this
            pipeline will be allocated.
        device_map (`str` or `dict[str, Union[int, str, torch.device]]`, *optional*):
            Sent directly as `model_kwargs` (just a simpler shortcut). When the `accelerate` library is present, set
            `device_map="auto"` to compute the most optimized `device_map` automatically (see
            [here](https://huggingface.co/docs/accelerate/main/en/package_reference/big_modeling#accelerate.cpu_offload)
            for more information).
            <Tip warning={true}>
            Do not use `device_map` and `device` at the same time, as they will conflict.
            </Tip>
        torch_dtype (`str` or `torch.dtype`, *optional*, defaults to `"auto"`):
            Sent directly as `model_kwargs` (just a simpler shortcut) to use the available precision for this model
            (`torch.float16`, `torch.bfloat16`, ... or `"auto"`).
        trust_remote_code (`bool`, *optional*, defaults to `False`):
            Whether or not to allow for custom code defined on the Hub in their own modeling, configuration,
            tokenization or even pipeline files. This option should only be set to `True` for repositories you trust
            and in which you have read the code, as it will execute code present on the Hub on your local machine.
        model_kwargs (`dict[str, Any]`, *optional*):
            Additional dictionary of keyword arguments passed along to the model's `from_pretrained(...,
            **model_kwargs)` function.
        pipeline_class (`type`, *optional*):
            Can be used to force using a custom pipeline class. If not provided, the default pipeline class for the
            specified task will be used.
        accelerator (`str`, *optional*):
            The accelerator to use, either `"ort"` for ONNX Runtime, `"ov"` for OpenVINO, or `"ipex"` for Intel
            Extension for PyTorch. If no accelerator is specified, will default to the first available backend,
            preferring OpenVINO, then ONNX Runtime, then IPEX.
        kwargs (`dict[str, Any]`, *optional*):
            Additional keyword arguments passed along to the specific pipeline init (see the documentation for the
            corresponding pipeline class for possible values).

    Returns:
        [`Pipeline`]: A suitable pipeline for the task.
    Examples:

    ```python
    >>> from optimum.pipelines import pipeline

    >>> # Sentiment analysis pipeline with default model, using OpenVINO
    >>> analyzer = pipeline("sentiment-analysis", accelerator="ov")

    >>> # Question answering pipeline, specifying the checkpoint identifier, with IPEX
    >>> oracle = pipeline(
    ...     "question-answering",
    ...     model="distilbert/distilbert-base-cased-distilled-squad",
    ...     tokenizer="google-bert/bert-base-cased",
    ...     accelerator="ipex",
    ... )

    >>> # Named entity recognition pipeline, passing in a specific model and tokenizer, with ONNX Runtime
    >>> from transformers import AutoTokenizer
    >>> from optimum.onnxruntime import ORTModelForTokenClassification

    >>> model = ORTModelForTokenClassification.from_pretrained("dbmdz/bert-large-cased-finetuned-conll03-english")
    >>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-cased")
    >>> recognizer = pipeline("ner", model=model, tokenizer=tokenizer)
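
    >>> # With no accelerator specified, the first available backend is picked
    >>> # automatically (OpenVINO, then ONNX Runtime, then IPEX)
    >>> classifier = pipeline("text-classification")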
    ```
    """

    if accelerator is None:
        # Pick the first available backend, preferring OpenVINO, then ONNX Runtime, then IPEX.
        # Note: this default could also take other factors into account, e.g. the target device or type(model).
        if is_optimum_intel_available() and is_openvino_available():
            accelerator = "ov"
        elif is_optimum_onnx_available() and is_onnxruntime_available():
            accelerator = "ort"
        elif is_optimum_intel_available() and is_ipex_available():
            accelerator = "ipex"
        else:
            raise ImportError(
                "You need to install either `optimum-onnx[onnxruntime]`, `optimum-intel[openvino]`, or "
                "`optimum-intel[ipex]` to use this pipeline."
            )

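    # Dispatch to the backend-specific pipeline factory, forwarding all shared arguments unchanged.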
    if accelerator == "ort":
        if not (is_optimum_onnx_available() and is_onnxruntime_available()):
            raise ImportError("You need to install `optimum-onnx[onnxruntime]` to use ONNX Runtime models.")

        from optimum.onnxruntime import pipeline as ort_pipeline

        return ort_pipeline(
            task=task,
            model=model,
            config=config,
            tokenizer=tokenizer,
            feature_extractor=feature_extractor,
            image_processor=image_processor,
            processor=processor,
            framework=framework,
            revision=revision,
            use_fast=use_fast,
            token=token,
            device=device,
            device_map=device_map,
            torch_dtype=torch_dtype,
            trust_remote_code=trust_remote_code,
            model_kwargs=model_kwargs,
            pipeline_class=pipeline_class,
            **kwargs,
        )
    elif accelerator in ["ov", "ipex"]:
        if accelerator == "ov" and not (is_optimum_intel_available() and is_openvino_available()):
            raise ImportError("You need to install `optimum-intel[openvino]` to use OpenVINO models.")
        elif accelerator == "ipex" and not (is_optimum_intel_available() and is_ipex_available()):
            raise ImportError("You need to install `optimum-intel[ipex]` to use Intel Extension for PyTorch models.")

        from optimum.intel import pipeline as intel_pipeline

        return intel_pipeline(
            task=task,
            model=model,
            config=config,
            tokenizer=tokenizer,
            feature_extractor=feature_extractor,
            image_processor=image_processor,
            processor=processor,
            framework=framework,
            revision=revision,
            use_fast=use_fast,
            token=token,
            device=device,
            device_map=device_map,
            torch_dtype=torch_dtype,
            trust_remote_code=trust_remote_code,
            model_kwargs=model_kwargs,
            pipeline_class=pipeline_class,
            accelerator=accelerator,
            **kwargs,
        )
    else:
        raise ValueError(f"Accelerator {accelerator} not recognized. Please use 'ort', 'ov', or 'ipex'.")
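

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the module's public API).
    # It assumes at least one supported backend is installed (optimum-intel[openvino],
    # optimum-onnx[onnxruntime], or optimum-intel[ipex]) so that accelerator
    # auto-detection above succeeds, and it downloads a default model on first run.
    classifier = pipeline("text-classification")
    print(classifier("Accelerated inference with optimum pipelines is straightforward."))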