Commit b5873e2

optimum pipelines
1 parent 2a269df commit b5873e2

6 files changed: +295 −476 lines changed

docs/source/quicktour.mdx

Lines changed: 0 additions & 36 deletions
@@ -129,42 +129,6 @@ To train transformers on Habana's Gaudi processors, 🤗 Optimum provides a `Gau
 
 You can find more examples in the [documentation](https://huggingface.co/docs/optimum/habana/quickstart) and in the [examples](https://github.com/huggingface/optimum-habana/tree/main/examples).
 
-
-#### ONNX Runtime
-
-To train transformers with ONNX Runtime's acceleration features, 🤗 Optimum provides an `ORTTrainer` that is very similar to the 🤗 Transformers [Trainer](https://huggingface.co/docs/transformers/main_classes/trainer). Here is a simple example:
-
-```diff
-- from transformers import Trainer, TrainingArguments
-+ from optimum.onnxruntime import ORTTrainer, ORTTrainingArguments
-
-  # Download a pretrained model from the Hub
-  model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
-
-  # Define the training arguments
-- training_args = TrainingArguments(
-+ training_args = ORTTrainingArguments(
-      output_dir="path/to/save/folder/",
-      optim="adamw_ort_fused",
-      ...
-  )
-
-  # Create an ONNX Runtime Trainer
-- trainer = Trainer(
-+ trainer = ORTTrainer(
-      model=model,
-      args=training_args,
-      train_dataset=train_dataset,
-+     feature="text-classification", # The model type to export to ONNX
-      ...
-  )
-
-  # Use ONNX Runtime for training!
-  trainer.train()
-```
-
-You can find more examples in the [documentation](https://huggingface.co/docs/optimum/onnxruntime/usage_guides/trainer) and in the [examples](https://github.com/huggingface/optimum/tree/main/examples/onnxruntime/training).
-
 ## Out of the box ONNX export
 
 The Optimum library handles out of the box the ONNX export of Transformers and Diffusers models!
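For reference, the out-of-the-box export that the surviving section refers to can be exercised directly from Python. A minimal sketch, assuming an optimum installation with ONNX Runtime support; the checkpoint name is illustrative:

```python
# Minimal sketch: convert a Transformers checkpoint to ONNX at load time and run it.
from optimum.onnxruntime import ORTModelForSequenceClassification
from transformers import AutoTokenizer

checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"

# export=True triggers the ONNX export before wrapping the model for ONNX Runtime.
model = ORTModelForSequenceClassification.from_pretrained(checkpoint, export=True)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

inputs = tokenizer("The export worked out of the box!", return_tensors="pt")
logits = model(**inputs).logits
print(logits.argmax(-1))
```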

optimum/exporters/utils.py

Lines changed: 2 additions & 2 deletions
@@ -26,8 +26,8 @@
 
 from ..utils import (
     DIFFUSERS_MINIMUM_VERSION,
-    check_if_diffusers_greater,
     is_diffusers_available,
+    is_diffusers_version,
     logging,
 )
 from ..utils.import_utils import _diffusers_version
@@ -38,7 +38,7 @@
 
 
 if is_diffusers_available():
-    if not check_if_diffusers_greater(DIFFUSERS_MINIMUM_VERSION.base_version):
+    if is_diffusers_version("<", DIFFUSERS_MINIMUM_VERSION.base_version):
         raise ImportError(
             f"We found an older version of diffusers {_diffusers_version} but we require diffusers to be >= {DIFFUSERS_MINIMUM_VERSION}. "
             "Please update diffusers by running `pip install --upgrade diffusers`"

optimum/pipelines/__init__.py

Lines changed: 273 additions & 6 deletions
@@ -1,5 +1,5 @@
 # coding=utf-8
-# Copyright 2023 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,10 +12,277 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+"""Pipelines running different backends."""
 
-from .pipelines_base import (
-    MAPPING_LOADING_FUNC,
-    ORT_SUPPORTED_TASKS,
-    load_ort_pipeline,
-    pipeline,
+from typing import TYPE_CHECKING, Any, Optional, Union
+
+import torch
+
+from optimum.utils.import_utils import (
+    is_ipex_available,
+    is_onnxruntime_available,
+    is_openvino_available,
+    is_optimum_intel_available,
+    is_optimum_onnx_available,
 )
+
+
+if TYPE_CHECKING:
+    from transformers import (
+        BaseImageProcessor,
+        FeatureExtractionMixin,
+        Pipeline,
+        PretrainedConfig,
+        PreTrainedModel,
+        PreTrainedTokenizer,
+        PreTrainedTokenizerFast,
+        ProcessorMixin,
+        TFPreTrainedModel,
+    )
+
+
+# The docstring is simply a copy of transformers.pipelines.pipeline's doc with minor modifications
+# to reflect the fact that this pipeline loads accelerated models using optimum.
+def pipeline(
+    task: Optional[str] = None,
+    model: Optional[Union[str, "PreTrainedModel", "TFPreTrainedModel"]] = None,
+    config: Optional[Union[str, "PretrainedConfig"]] = None,
+    tokenizer: Optional[Union[str, "PreTrainedTokenizer", "PreTrainedTokenizerFast"]] = None,
+    feature_extractor: Optional[Union[str, "FeatureExtractionMixin"]] = None,
+    image_processor: Optional[Union[str, "BaseImageProcessor"]] = None,
+    processor: Optional[Union[str, "ProcessorMixin"]] = None,
+    framework: Optional[str] = None,
+    revision: Optional[str] = None,
+    use_fast: bool = True,
+    token: Optional[Union[str, bool]] = None,
+    device: Optional[Union[int, str, "torch.device"]] = None,
+    device_map: Optional[Union[str, dict[str, Union[int, str]]]] = None,
+    torch_dtype: Optional[Union[str, "torch.dtype"]] = "auto",
+    trust_remote_code: Optional[bool] = None,
+    model_kwargs: Optional[dict[str, Any]] = None,
+    pipeline_class: Optional[Any] = None,
+    accelerator: Optional[str] = None,
+    **kwargs: Any,
+) -> "Pipeline":
+    """Utility factory method to build a [`Pipeline`] with an optimum accelerated model, similar to `transformers.pipeline`.
+    A pipeline consists of:
+    - One or more components for pre-processing model inputs, such as a [tokenizer](tokenizer),
+      [image_processor](image_processor), [feature_extractor](feature_extractor), or [processor](processors).
+    - A [model](model) that generates predictions from the inputs.
+    - Optional post-processing steps to refine the model's output, which can also be handled by processors.
+    <Tip>
+    While there are such optional arguments as `tokenizer`, `feature_extractor`, `image_processor`, and `processor`,
+    they shouldn't be specified all at once. If these components are not provided, `pipeline` will try to load the
+    required ones automatically. In case you want to provide these components explicitly, please refer to a
+    specific pipeline in order to get more details regarding what components are required.
+    </Tip>
+    Args:
+        task (`str`):
+            The task defining which pipeline will be returned. Currently accepted tasks are:
+            - `"audio-classification"`: will return an [`AudioClassificationPipeline`].
+            - `"automatic-speech-recognition"`: will return an [`AutomaticSpeechRecognitionPipeline`].
+            - `"depth-estimation"`: will return a [`DepthEstimationPipeline`].
+            - `"document-question-answering"`: will return a [`DocumentQuestionAnsweringPipeline`].
+            - `"feature-extraction"`: will return a [`FeatureExtractionPipeline`].
+            - `"fill-mask"`: will return a [`FillMaskPipeline`].
+            - `"image-classification"`: will return an [`ImageClassificationPipeline`].
+            - `"image-feature-extraction"`: will return an [`ImageFeatureExtractionPipeline`].
+            - `"image-segmentation"`: will return an [`ImageSegmentationPipeline`].
+            - `"image-text-to-text"`: will return an [`ImageTextToTextPipeline`].
+            - `"image-to-image"`: will return an [`ImageToImagePipeline`].
+            - `"image-to-text"`: will return an [`ImageToTextPipeline`].
+            - `"mask-generation"`: will return a [`MaskGenerationPipeline`].
+            - `"object-detection"`: will return an [`ObjectDetectionPipeline`].
+            - `"question-answering"`: will return a [`QuestionAnsweringPipeline`].
+            - `"summarization"`: will return a [`SummarizationPipeline`].
+            - `"table-question-answering"`: will return a [`TableQuestionAnsweringPipeline`].
+            - `"text2text-generation"`: will return a [`Text2TextGenerationPipeline`].
+            - `"text-classification"` (alias `"sentiment-analysis"` available): will return a
+              [`TextClassificationPipeline`].
+            - `"text-generation"`: will return a [`TextGenerationPipeline`].
+            - `"text-to-audio"` (alias `"text-to-speech"` available): will return a [`TextToAudioPipeline`].
+            - `"token-classification"` (alias `"ner"` available): will return a [`TokenClassificationPipeline`].
+            - `"translation"`: will return a [`TranslationPipeline`].
+            - `"translation_xx_to_yy"`: will return a [`TranslationPipeline`].
+            - `"video-classification"`: will return a [`VideoClassificationPipeline`].
+            - `"visual-question-answering"`: will return a [`VisualQuestionAnsweringPipeline`].
+            - `"zero-shot-classification"`: will return a [`ZeroShotClassificationPipeline`].
+            - `"zero-shot-image-classification"`: will return a [`ZeroShotImageClassificationPipeline`].
+            - `"zero-shot-audio-classification"`: will return a [`ZeroShotAudioClassificationPipeline`].
+            - `"zero-shot-object-detection"`: will return a [`ZeroShotObjectDetectionPipeline`].
+        model (`str` or [`ORTModel`] or [`OVModel`], *optional*):
+            The model that will be used by the pipeline to make predictions. This can be a model identifier or an
+            actual instance of an accelerated model inheriting from [`ORTModel`] or [`OVModel`].
+            If not provided, the default for the `task` will be loaded.
+        config (`str` or [`PretrainedConfig`], *optional*):
+            The configuration that will be used by the pipeline to instantiate the model. This can be a model
+            identifier or an actual pretrained model configuration inheriting from [`PretrainedConfig`].
+            If not provided, the default configuration file for the requested model will be used. That means that if
+            `model` is given, its default configuration will be used. However, if `model` is not supplied, this
+            `task`'s default model's config is used instead.
+        tokenizer (`str` or [`PreTrainedTokenizer`], *optional*):
+            The tokenizer that will be used by the pipeline to encode data for the model. This can be a model
+            identifier or an actual pretrained tokenizer inheriting from [`PreTrainedTokenizer`].
+            If not provided, the default tokenizer for the given `model` will be loaded (if it is a string). If `model`
+            is not specified or not a string, then the default tokenizer for `config` is loaded (if it is a string).
+            However, if `config` is also not given or not a string, then the default tokenizer for the given `task`
+            will be loaded.
+        feature_extractor (`str` or [`PreTrainedFeatureExtractor`], *optional*):
+            The feature extractor that will be used by the pipeline to encode data for the model. This can be a model
+            identifier or an actual pretrained feature extractor inheriting from [`PreTrainedFeatureExtractor`].
+            Feature extractors are used for non-NLP models, such as Speech or Vision models as well as multi-modal
+            models. Multi-modal models will also require a tokenizer to be passed.
+            If not provided, the default feature extractor for the given `model` will be loaded (if it is a string). If
+            `model` is not specified or not a string, then the default feature extractor for `config` is loaded (if it
+            is a string). However, if `config` is also not given or not a string, then the default feature extractor
+            for the given `task` will be loaded.
+        image_processor (`str` or [`BaseImageProcessor`], *optional*):
+            The image processor that will be used by the pipeline to preprocess images for the model. This can be a
+            model identifier or an actual image processor inheriting from [`BaseImageProcessor`].
+            Image processors are used for Vision models and multi-modal models that require image inputs. Multi-modal
+            models will also require a tokenizer to be passed.
+            If not provided, the default image processor for the given `model` will be loaded (if it is a string). If
+            `model` is not specified or not a string, then the default image processor for `config` is loaded (if it is
+            a string).
+        processor (`str` or [`ProcessorMixin`], *optional*):
+            The processor that will be used by the pipeline to preprocess data for the model. This can be a model
+            identifier or an actual processor inheriting from [`ProcessorMixin`].
+            Processors are used for multi-modal models that require multi-modal inputs, for example, a model that
+            requires both text and image inputs.
+            If not provided, the default processor for the given `model` will be loaded (if it is a string). If `model`
+            is not specified or not a string, then the default processor for `config` is loaded (if it is a string).
+        framework (`str`, *optional*):
+            The framework to use, either `"pt"` for PyTorch or `"tf"` for TensorFlow. The specified framework must be
+            installed.
+            If no framework is specified, will default to the one currently installed. If no framework is specified and
+            both frameworks are installed, will default to the framework of the `model`, or to PyTorch if no model is
+            provided.
+        revision (`str`, *optional*, defaults to `"main"`):
+            When passing a task name or a string model identifier: The specific model version to use. It can be a
+            branch name, a tag name, or a commit id, since we use a git-based system for storing models and other
+            artifacts on huggingface.co, so `revision` can be any identifier allowed by git.
+        use_fast (`bool`, *optional*, defaults to `True`):
+            Whether or not to use a Fast tokenizer if possible (a [`PreTrainedTokenizerFast`]).
+        token (`str` or `bool`, *optional*):
+            The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
+            when running `hf auth login` (stored in `~/.huggingface`).
+        device (`int` or `str` or `torch.device`):
+            Defines the device (*e.g.*, `"cpu"`, `"cuda:1"`, `"mps"`, or a GPU ordinal rank like `1`) on which this
+            pipeline will be allocated.
+        device_map (`str` or `dict[str, Union[int, str, torch.device]]`, *optional*):
+            Sent directly as `model_kwargs` (just a simpler shortcut). When the `accelerate` library is present, set
+            `device_map="auto"` to compute the most optimized `device_map` automatically (see
+            [here](https://huggingface.co/docs/accelerate/main/en/package_reference/big_modeling#accelerate.cpu_offload)
+            for more information).
+            <Tip warning={true}>
+            Do not use `device_map` AND `device` at the same time as they will conflict.
+            </Tip>
+        torch_dtype (`str` or `torch.dtype`, *optional*):
+            Sent directly as `model_kwargs` (just a simpler shortcut) to use the available precision for this model
+            (`torch.float16`, `torch.bfloat16`, ... or `"auto"`).
+        trust_remote_code (`bool`, *optional*, defaults to `False`):
+            Whether or not to allow for custom code defined on the Hub in their own modeling, configuration,
+            tokenization or even pipeline files. This option should only be set to `True` for repositories you trust
+            and in which you have read the code, as it will execute code present on the Hub on your local machine.
+        model_kwargs (`dict[str, Any]`, *optional*):
+            Additional dictionary of keyword arguments passed along to the model's `from_pretrained(...,
+            **model_kwargs)` function.
+        pipeline_class (`type`, *optional*):
+            Can be used to force using a custom pipeline class. If not provided, the default pipeline class for the
+            specified task will be used.
+        accelerator (`str`, *optional*):
+            The accelerator to use, either `"ort"` for ONNX Runtime, `"ov"` for OpenVINO, or `"ipex"` for Intel
+            Extension for PyTorch. If no accelerator is specified, will default to the one currently installed/available.
+        kwargs (`dict[str, Any]`, *optional*):
+            Additional keyword arguments passed along to the specific pipeline init (see the documentation for the
+            corresponding pipeline class for possible values).
+    Returns:
+        [`Pipeline`]: A suitable pipeline for the task.
+    Examples:
+    ```python
+    >>> from optimum.pipelines import pipeline
+    >>> # Sentiment analysis pipeline with default model, using OpenVINO
+    >>> analyzer = pipeline("sentiment-analysis", accelerator="ov")
+    >>> # Question answering pipeline, specifying the checkpoint identifier, with IPEX
+    >>> oracle = pipeline(
+    ...     "question-answering", model="distilbert/distilbert-base-cased-distilled-squad", tokenizer="google-bert/bert-base-cased", accelerator="ipex"
+    ... )
+    >>> # Named entity recognition pipeline, passing in a specific model and tokenizer, with ONNX Runtime
+    >>> model = ORTModelForTokenClassification.from_pretrained("dbmdz/bert-large-cased-finetuned-conll03-english")
+    >>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-cased")
+    >>> recognizer = pipeline("ner", model=model, tokenizer=tokenizer)
+    ```
+    """
+
+    if accelerator is None:
+        # probably needs to check a couple of things here, like the target device, type(model), etc.
+        if is_optimum_intel_available() and is_openvino_available():
+            accelerator = "ov"
+        elif is_optimum_onnx_available() and is_onnxruntime_available():
+            accelerator = "ort"
+        elif is_optimum_intel_available() and is_ipex_available():
+            accelerator = "ipex"
+        else:
+            raise ImportError(
+                "You need to install either optimum-onnx[onnxruntime], optimum-intel[openvino], or optimum-intel[ipex] to use this pipeline."
+            )
+
+    if accelerator == "ort":
+        if not (is_optimum_onnx_available() and is_onnxruntime_available()):
+            raise ImportError("You need to install `optimum-onnx[onnxruntime]` to use ONNX Runtime models.")
+
+        from optimum.onnxruntime import pipeline as ort_pipeline
+
+        return ort_pipeline(
+            task=task,
+            model=model,
+            config=config,
+            tokenizer=tokenizer,
+            feature_extractor=feature_extractor,
+            image_processor=image_processor,
+            processor=processor,
+            framework=framework,
+            revision=revision,
+            use_fast=use_fast,
+            token=token,
+            device=device,
+            device_map=device_map,
+            torch_dtype=torch_dtype,
+            trust_remote_code=trust_remote_code,
+            model_kwargs=model_kwargs,
+            pipeline_class=pipeline_class,
+            **kwargs,
+        )
+    elif accelerator in ["ov", "ipex"]:
+        if accelerator == "ov" and not (is_optimum_intel_available() and is_openvino_available()):
+            raise ImportError("You need to install `optimum-intel[openvino]` to use OpenVINO models.")
+        elif accelerator == "ipex" and not (is_optimum_intel_available() and is_ipex_available()):
+            raise ImportError(
+                "You need to install `optimum-intel[ipex]` to use Intel Extension for PyTorch models."
+            )
+
+        from optimum.intel import pipeline as intel_pipeline
+
+        return intel_pipeline(
+            task=task,
+            model=model,
+            config=config,
+            tokenizer=tokenizer,
+            feature_extractor=feature_extractor,
+            image_processor=image_processor,
+            processor=processor,
+            framework=framework,
+            revision=revision,
+            use_fast=use_fast,
+            token=token,
+            device=device,
+            device_map=device_map,
+            torch_dtype=torch_dtype,
+            trust_remote_code=trust_remote_code,
+            model_kwargs=model_kwargs,
+            pipeline_class=pipeline_class,
+            accelerator=accelerator,
+            **kwargs,
+        )
+    else:
+        raise ValueError(f"Accelerator {accelerator} not recognized. Please use 'ort', 'ov' or 'ipex'.")
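Put together, the new `optimum.pipelines.pipeline` is a thin dispatcher: it resolves an `accelerator` (explicitly, or by probing what is installed in the order OpenVINO, ONNX Runtime, IPEX) and forwards every argument to the backend's own `pipeline` function. A minimal usage sketch, assuming `optimum-onnx[onnxruntime]` is installed so the `"ort"` branch is taken:

```python
from optimum.pipelines import pipeline

# Explicit backend: forwards to optimum.onnxruntime.pipeline under the hood.
classifier = pipeline("text-classification", accelerator="ort")
print(classifier("Dispatching through optimum works!"))

# With accelerator=None, the first available backend is picked automatically.
detector = pipeline("text-classification")
```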
