Commit 16618fc

Introduce ORTSessionMixin and enable general io binding (works for diffusers as well) (#2234)
* refactor ort session mixin and enable integral and simple diffusers io binding as a result
* style
* distribute onnxruntime tests
* no need to clean disk for fast tests
* move diffusion tests to diffusion
* fix
* test
* providers
* fix
* fix and get rid of io binding helpers
* get rid of the models folder in onnxruntime
* style
* comments
* remove _from_transformers
* fix trust remote code
* fix local model test (use_cache should only be passed to causal lm)
* reuse input buffers as output buffers in diffusion models
* fix sess_option
* override library for sentence transformers feature extraction
* fix no cross attention
* remove lib and default to transformers
* better
* flaky decoder
* add saving session utils
* test
* more refactoring attempt (separating encoder/decoder parts from parent model)
* remove print
* fix
* _infer_onnx_filename as private method
* fix model name
* fix more model paths
* fixes
* Update optimum/onnxruntime/base.py
  Co-authored-by: Ella Charlaix <[email protected]>
* added main export library guards, and restricted when to force eager
* style
* review suggestions and added warning when properties are not consistent across
* ORTParentMixin
* deprecate instantiating ORTModel, ORTModelForCausalLM and ORTModelForConditionalGeneration with positional arguments
* keyword arguments
* diffusion
* known output buffers
* style
* typo
* id
* fix
* more typos
* Update optimum/onnxruntime/modeling_diffusion.py
  Co-authored-by: Ella Charlaix <[email protected]>
* remove task argument from _export
* deprecate and fix
* fix
* style
* slim later
* misc fixes and extensions in testing
* allow passing export arguments to diffusion pipeline (e.g. exports on cuda device, with specific dtype, etc.)

---------

Co-authored-by: Ella Charlaix <[email protected]>
1 parent d1af494 commit 16618fc

22 files changed (+2589 / -2874 lines)
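Background for reviewers: the "general io binding" in the title refers to ONNX Runtime's IO binding mechanism, which lets a session read inputs from and write outputs to pre-allocated buffers instead of copying tensors on every call; this refactor centralizes that logic in ORTSessionMixin so diffusion pipelines can use it too. The sketch below shows only the underlying pattern from onnxruntime's public API; the model path and tensor names ("model.onnx", "input_ids", "logits") are placeholders, not files touched by this commit.

```python
import numpy as np
import onnxruntime as ort

# Placeholder model path and tensor names; substitute those of an actual exported model.
session = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
binding = session.io_binding()

# Bind an input to a buffer we already own (host memory here; device buffers work too).
input_ids = np.zeros((1, 8), dtype=np.int64)
binding.bind_cpu_input("input_ids", input_ids)

# Let ONNX Runtime allocate the output buffer, which can then be reused across calls.
binding.bind_output("logits")

session.run_with_iobinding(binding)
logits = binding.copy_outputs_to_cpu()[0]
```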

.github/workflows/test_onnxruntime.yml

Lines changed: 13 additions & 4 deletions
@@ -27,12 +27,20 @@ jobs:
       matrix:
         python-version: [3.9]
         runs-on: [ubuntu-22.04]
+        test_file: [
+          test_timm.py,
+          test_modeling.py, # todo: split into test_encoder, test_decoder and test_encoder_decoder
+          test_diffusion.py,
+          test_optimization.py,
+          test_quantization.py,
+          test_utils.py,
+        ]
 
     runs-on: ${{ matrix.runs-on }}
 
     steps:
       - name: Free Disk Space (Ubuntu)
-        if: matrix.runs-on == 'ubuntu-22.04'
+        if: matrix.test_file == 'test_modeling.py'
         uses: jlumbroso/free-disk-space@main
 
       - name: Checkout code
@@ -50,11 +58,12 @@ jobs:
           pip install .[tests,onnxruntime] diffusers
 
       - name: Test with pytest (in series)
+        if: matrix.test_file == 'test_modeling.py'
         run: |
-          pytest tests/onnxruntime -m "run_in_series" --durations=0 -vvvv
+          pytest tests/onnxruntime/test_modeling.py -m "run_in_series" --durations=0 -vvvv
 
       - name: Test with pytest (in parallel)
         run: |
-          pytest tests/onnxruntime -m "not run_in_series" --durations=0 -vvvv -n auto
+          pytest tests/onnxruntime/${{ matrix.test_file }} -m "not run_in_series" --durations=0 -vvvv -n auto
         env:
-          HF_HUB_READ_TOKEN: ${{ secrets.HF_HUB_READ_TOKEN }}
+          HF_HUB_READ_TOKEN: ${{ secrets.HF_HUB_READ_TOKEN }}
optimum/exporters/onnx/__main__.py

Lines changed: 51 additions & 5 deletions
@@ -25,7 +25,12 @@
 
 from ...commands.export.onnx import parse_args_onnx
 from ...utils import DEFAULT_DUMMY_SHAPES, logging
-from ...utils.import_utils import is_transformers_version
+from ...utils.import_utils import (
+    is_diffusers_available,
+    is_sentence_transformers_available,
+    is_timm_available,
+    is_transformers_version,
+)
 from ...utils.save_utils import maybe_load_preprocessors
 from ..tasks import TasksManager
 from ..utils import DisableCompileContextManager
@@ -223,12 +228,29 @@ def main_export(
            " and passing it is not required anymore."
        )
 
-    if task in ["stable-diffusion", "stable-diffusion-xl"]:
+    if task in ["stable-diffusion", "stable-diffusion-xl", "latent-consistency"]:
         logger.warning(
             f"The task `{task}` is deprecated and will be removed in a future release of Optimum. "
             "Please use one of the following tasks instead: `text-to-image`, `image-to-image`, `inpainting`."
         )
 
+    if library_name == "sentence_transformers" and not is_sentence_transformers_available():
+        raise ImportError(
+            "The library `sentence_transformers` was specified, but it is not installed. "
+            "Please install it with `pip install sentence-transformers`."
+        )
+
+    if library_name == "diffusers" and not is_diffusers_available():
+        raise ImportError(
+            "The library `diffusers` was specified, but it is not installed. "
+            "Please install it with `pip install diffusers`."
+        )
+
+    if library_name == "timm" and not is_timm_available():
+        raise ImportError(
+            "The library `timm` was specified, but it is not installed. Please install it with `pip install timm`."
+        )
+
     original_task = task
     task = TasksManager.map_from_synonym(task)
 
@@ -241,6 +263,22 @@
         library_name = TasksManager.infer_library_from_model(
             model_name_or_path, subfolder=subfolder, revision=revision, cache_dir=cache_dir, token=token
         )
+        if library_name == "sentence_transformers" and not is_sentence_transformers_available():
+            logger.warning(
+                "The library name was inferred as `sentence_transformers`, which is not installed. "
+                "Falling back to `transformers` to avoid breaking the export."
+            )
+            library_name = "transformers"
+        elif library_name == "timm" and not is_timm_available():
+            raise ImportError(
+                "The library name was inferred as `timm`, which is not installed. "
+                "Please install it with `pip install timm`."
+            )
+        elif library_name == "diffusers" and not is_diffusers_available():
+            raise ImportError(
+                "The library name was inferred as `diffusers`, which is not installed. "
+                "Please install it with `pip install diffusers`."
+            )
 
     torch_dtype = None
     if framework == "pt":
@@ -258,7 +296,14 @@
 
     if task == "auto":
         try:
-            task = TasksManager.infer_task_from_model(model_name_or_path, library_name=library_name)
+            task = TasksManager.infer_task_from_model(
+                model_name_or_path,
+                subfolder=subfolder,
+                revision=revision,
+                cache_dir=cache_dir,
+                token=token,
+                library_name=library_name,
+            )
         except KeyError as e:
             raise KeyError(
                 f"The task could not be automatically inferred. Please provide the argument --task with the relevant task from {', '.join(TasksManager.get_all_tasks())}. Detailed error: {e}"
@@ -299,8 +344,9 @@
             f"Asked to export a {model_type} model for the task {task}{autodetected_message}, but the Optimum ONNX exporter only supports the tasks {', '.join(model_tasks.keys())} for {model_type}. Please use a supported task. Please open an issue at https://github.com/huggingface/optimum/issues if you would like the task {task} to be supported in the ONNX export for {model_type}."
         )
 
-    # TODO: Fix in Transformers so that SdpaAttention class can be exported to ONNX. `attn_implementation` is introduced in Transformers 4.36.
-    if model_type in SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED and is_transformers_version(">=", "4.35.99"):
+    # TODO: Fix in Transformers so that SdpaAttention class can be exported to ONNX.
+    # This was fixed in transformers 4.42.0, we can remove it when the minimum transformers version is updated to 4.42.
+    if model_type in SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED and is_transformers_version("<", "4.42"):
         loading_kwargs["attn_implementation"] = "eager"
 
     with DisableCompileContextManager():
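From a user's perspective, the new guards fail fast when a requested or inferred library is missing instead of erroring deep inside the export. A minimal illustration of the entry point (the model id and output directory are placeholders; passing `library_name="diffusers"` without diffusers installed would now raise an ImportError up front):

```python
from optimum.exporters.onnx import main_export

# Placeholder model id and output path.
main_export(
    model_name_or_path="distilbert-base-uncased",
    output="exported_onnx",
    task="auto",  # task inference now forwards subfolder/revision/cache_dir/token
)
```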

optimum/exporters/onnx/model_configs.py

Lines changed: 29 additions & 14 deletions
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Model specific ONNX configurations."""
+
 import math
 import random
 import warnings
@@ -1337,45 +1338,51 @@ class VaeEncoderOnnxConfig(VisionOnnxConfig):
     DEFAULT_ONNX_OPSET = 14
 
     NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args(
-        num_channels="in_channels",
-        image_size="sample_size",
-        allow_new=True,
+        num_channels="in_channels", image_size="sample_size", allow_new=True
     )
 
     @property
     def inputs(self) -> Dict[str, Dict[int, str]]:
         return {
-            "sample": {0: "batch_size", 2: "height", 3: "width"},
+            "sample": {0: "batch_size", 2: "sample_height", 3: "sample_width"},
         }
 
     @property
     def outputs(self) -> Dict[str, Dict[int, str]]:
+        down_sampling_factor = 2 ** (len(self._normalized_config.down_block_types) - 1)
         return {
-            "latent_parameters": {0: "batch_size", 2: "height_latent", 3: "width_latent"},
+            "latent_parameters": {
+                0: "batch_size",
+                2: f"sample_height / {down_sampling_factor}",
+                3: f"sample_width / {down_sampling_factor}",
+            },
         }
 
 
 class VaeDecoderOnnxConfig(VisionOnnxConfig):
-    ATOL_FOR_VALIDATION = 1e-4
+    ATOL_FOR_VALIDATION = 3e-4
     # The ONNX export of a CLIPText architecture, another Stable Diffusion component, needs the Trilu
     # operator support, available since opset 14
     DEFAULT_ONNX_OPSET = 14
 
-    NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args(
-        num_channels="latent_channels",
-        allow_new=True,
-    )
+    NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args(num_channels="latent_channels", allow_new=True)
 
     @property
     def inputs(self) -> Dict[str, Dict[int, str]]:
         return {
-            "latent_sample": {0: "batch_size", 2: "height_latent", 3: "width_latent"},
+            "latent_sample": {0: "batch_size", 2: "latent_height", 3: "latent_width"},
         }
 
     @property
     def outputs(self) -> Dict[str, Dict[int, str]]:
+        upsampling_factor = 2 ** (len(self._normalized_config.up_block_types) - 1)
+
         return {
-            "sample": {0: "batch_size", 2: "height", 3: "width"},
+            "sample": {
+                0: "batch_size",
+                2: f"latent_height * {upsampling_factor}",
+                3: f"latent_width * {upsampling_factor}",
+            },
         }
 
 
@@ -1815,9 +1822,17 @@ class MusicgenOnnxConfig(OnnxSeq2SeqConfigWithPast):
     DEFAULT_ONNX_OPSET = 14
 
     VARIANTS = {
-        "text-conditional-with-past": "Exports Musicgen to ONNX to generate audio samples conditioned on a text prompt (Reference: https://huggingface.co/docs/transformers/model_doc/musicgen#text-conditional-generation). This uses the decoder KV cache. The following subcomponents are exported:\n\t\t* text_encoder.onnx: corresponds to the text encoder part in https://github.com/huggingface/transformers/blob/v4.39.1/src/transformers/models/musicgen/modeling_musicgen.py#L1457.\n\t\t* encodec_decode.onnx: corresponds to the Encodec audio encoder part in https://github.com/huggingface/transformers/blob/v4.39.1/src/transformers/models/musicgen/modeling_musicgen.py#L2472-L2480.\n\t\t* decoder_model.onnx: The Musicgen decoder, without past key values input, and computing cross attention. Not required at inference (use decoder_model_merged.onnx instead).\n\t\t* decoder_with_past_model.onnx: The Musicgen decoder, with past_key_values input (KV cache filled), not computing cross attention. Not required at inference (use decoder_model_merged.onnx instead).\n\t\t* decoder_model_merged.onnx: The two previous models fused in one, to avoid duplicating weights. A boolean input `use_cache_branch` allows to select the branch to use. In the first forward pass where the KV cache is empty, dummy past key values inputs need to be passed and are ignored with use_cache_branch=False.\n\t\t* build_delay_pattern_mask.onnx: A model taking as input `input_ids`, `pad_token_id`, `max_length`, and building a delayed pattern mask to the input_ids. Implements https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/musicgen/modeling_musicgen.py#L1054.",
+        "text-conditional-with-past": """Exports Musicgen to ONNX to generate audio samples conditioned on a text prompt (Reference: https://huggingface.co/docs/transformers/model_doc/musicgen#text-conditional-generation).
+    This uses the decoder KV cache. The following subcomponents are exported:
+        * text_encoder.onnx: corresponds to the text encoder part in https://github.com/huggingface/transformers/blob/v4.39.1/src/transformers/models/musicgen/modeling_musicgen.py#L1457.
+        * encodec_decode.onnx: corresponds to the Encodec audio encoder part in https://github.com/huggingface/transformers/blob/v4.39.1/src/transformers/models/musicgen/modeling_musicgen.py#L2472-L2480.
+        * decoder_model.onnx: The Musicgen decoder, without past key values input, and computing cross attention. Not required at inference (use decoder_model_merged.onnx instead).
+        * decoder_with_past_model.onnx: The Musicgen decoder, with past_key_values input (KV cache filled), not computing cross attention. Not required at inference (use decoder_model_merged.onnx instead).
+        * decoder_model_merged.onnx: The two previous models fused in one, to avoid duplicating weights. A boolean input `use_cache_branch` allows to select the branch to use. In the first forward pass where the KV cache is empty, dummy past key values inputs need to be passed and are ignored with use_cache_branch=False.
+        * build_delay_pattern_mask.onnx: A model taking as input `input_ids`, `pad_token_id`, `max_length`, and building a delayed pattern mask to the input_ids. Implements https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/musicgen/modeling_musicgen.py#L1054.""",
     }
-    # TODO: support audio-prompted generation (- audio_encoder_encode.onnx: corresponds to the audio encoder part in https://github.com/huggingface/transformers/blob/f01e1609bf4dba146d1347c1368c8c49df8636f6/src/transformers/models/musicgen/modeling_musicgen.py#L2087.\n\t)
+    # TODO: support audio-prompted generation (audio_encoder_encode.onnx: corresponds to the audio encoder part
+    # in https://github.com/huggingface/transformers/blob/f01e1609bf4dba146d1347c1368c8c49df8636f6/src/transformers/models/musicgen/modeling_musicgen.py#L2087.)
     # With that, we have full Encodec support.
     DEFAULT_VARIANT = "text-conditional-with-past"
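To make the new symbolic VAE axes above concrete, here is a worked example assuming a typical Stable Diffusion VAE with four down/up blocks (the block counts, and therefore the factor of 8, are assumptions for illustration; the real values come from the model's config):

```python
# Assumed config values for illustration only.
down_block_types = ["DownEncoderBlock2D"] * 4
up_block_types = ["UpDecoderBlock2D"] * 4

down_sampling_factor = 2 ** (len(down_block_types) - 1)  # 2**3 == 8
upsampling_factor = 2 ** (len(up_block_types) - 1)  # 2**3 == 8

sample_height = sample_width = 512
latent_height = sample_height // down_sampling_factor  # 64, i.e. "sample_height / 8"
latent_width = sample_width // down_sampling_factor  # 64, i.e. "sample_width / 8"

decoded_height = latent_height * upsampling_factor  # 512 again, i.e. "latent_height * 8"
decoded_width = latent_width * upsampling_factor  # 512 again, i.e. "latent_width * 8"
```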

optimum/exporters/tasks.py

Lines changed: 2 additions & 2 deletions
@@ -1935,7 +1935,7 @@ def infer_task_from_model(
                 token=token,
                 library_name=library_name,
             )
-        elif type(model) == type:
+        elif type(model) is type:
             inferred_task_name = cls._infer_task_from_model_or_model_class(model_class=model)
         else:
             inferred_task_name = cls._infer_task_from_model_or_model_class(model=model)
@@ -2089,7 +2089,7 @@ def infer_library_from_model(
                 cache_dir=cache_dir,
                 token=token,
             )
-        elif type(model) == type:
+        elif type(model) is type:
             library_name = cls._infer_library_from_model_or_model_class(model_class=model)
         else:
             library_name = cls._infer_library_from_model_or_model_class(model=model)
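The `==` to `is` change only affects how the check is written, not what it tests: comparing against the `type` builtin by identity is the idiomatic way to ask "is this argument a class object rather than an instance?" (true for ordinary classes whose metaclass is `type`):

```python
class MyModel:
    pass

print(type(MyModel) is type)    # True: MyModel is itself a class
print(type(MyModel()) is type)  # False: MyModel() is an instance
```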

optimum/modeling_base.py

Lines changed: 18 additions & 34 deletions
@@ -20,7 +20,7 @@
 import warnings
 from abc import ABC, abstractmethod
 from pathlib import Path
-from typing import TYPE_CHECKING, Optional, Union
+from typing import TYPE_CHECKING, List, Optional, Union
 
 from huggingface_hub import HfApi
 from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
@@ -32,7 +32,17 @@
 
 
 if TYPE_CHECKING:
-    from transformers import PreTrainedModel, TFPreTrainedModel
+    from transformers import (
+        FeatureExtractionMixin,
+        ImageProcessingMixin,
+        PreTrainedModel,
+        ProcessorMixin,
+        SpecialTokensMixin,
+        TFPreTrainedModel,
+    )
+
+    PreprocessorT = Union[SpecialTokensMixin, FeatureExtractionMixin, ImageProcessingMixin, ProcessorMixin]
+    ModelT = Union["PreTrainedModel", "TFPreTrainedModel"]
 
 
 logger = logging.getLogger(__name__)
@@ -79,6 +89,7 @@
 """
 
 
+# TODO: Should be removed when we no longer use OptimizedModel for everything
 # workaround to enable compatibility between optimum models and transformers pipelines
 class PreTrainedModel(ABC):  # noqa: F811
     pass
@@ -89,11 +100,12 @@ class OptimizedModel(PreTrainedModel):
     base_model_prefix = "optimized_model"
     config_name = CONFIG_NAME
 
-    def __init__(self, model: Union["PreTrainedModel", "TFPreTrainedModel"], config: PretrainedConfig):
-        super().__init__()
+    def __init__(
+        self, model: Union["ModelT"], config: "PretrainedConfig", preprocessors: Optional[List["PreprocessorT"]] = None
+    ):
         self.model = model
         self.config = config
-        self.preprocessors = []
+        self.preprocessors = preprocessors or []
 
     def __call__(self, *args, **kwargs):
         return self.forward(*args, **kwargs)
@@ -291,27 +303,6 @@ def _from_pretrained(
         """Overwrite this method in subclass to define how to load your model from pretrained"""
         raise NotImplementedError("Overwrite this method in subclass to define how to load your model from pretrained")
 
-    @classmethod
-    def _from_transformers(
-        cls,
-        model_id: Union[str, Path],
-        config: PretrainedConfig,
-        use_auth_token: Optional[Union[bool, str]] = None,
-        token: Optional[Union[bool, str]] = None,
-        revision: Optional[str] = None,
-        force_download: bool = False,
-        cache_dir: str = HUGGINGFACE_HUB_CACHE,
-        subfolder: str = "",
-        local_files_only: bool = False,
-        trust_remote_code: bool = False,
-        **kwargs,
-    ) -> "OptimizedModel":
-        """Overwrite this method in subclass to define how to load your model from vanilla transformers model"""
-        raise NotImplementedError(
-            "`_from_transformers` method will be deprecated in a future release. Please override `_export` instead"
-            "to define how to load your model from vanilla transformers model"
-        )
-
     @classmethod
     def _export(
         cls,
@@ -366,13 +357,6 @@ def from_pretrained(
         if isinstance(model_id, Path):
             model_id = model_id.as_posix()
 
-        from_transformers = kwargs.pop("from_transformers", None)
-        if from_transformers is not None:
-            logger.warning(
-                "The argument `from_transformers` is deprecated, and will be removed in optimum 2.0. Use `export` instead"
-            )
-            export = from_transformers
-
         if len(model_id.split("@")) == 2:
             logger.warning(
                 f"Specifying the `revision` as @{model_id.split('@')[1]} is deprecated and will be removed in v1.23, please use the `revision` argument instead."
@@ -436,7 +420,7 @@ def from_pretrained(
                 trust_remote_code=trust_remote_code,
             )
 
-        from_pretrained_method = cls._from_transformers if export else cls._from_pretrained
+        from_pretrained_method = cls._export if export else cls._from_pretrained
 
         return from_pretrained_method(
             model_id=model_id,
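With `_from_transformers` removed, `from_pretrained(..., export=True)` dispatches directly to `cls._export`. The sketch below is a hypothetical subclass showing the shape of that contract; the class name and method bodies are made up for illustration and are not part of this commit:

```python
from optimum.modeling_base import OptimizedModel


class MyOptimizedModel(OptimizedModel):
    @classmethod
    def _from_pretrained(cls, model_id, config, **kwargs):
        # Load an already-exported model from `model_id` (illustrative stub).
        raise NotImplementedError

    @classmethod
    def _export(cls, model_id, config, **kwargs):
        # Convert the vanilla checkpoint here, then hand off to the regular loader.
        return cls._from_pretrained(model_id=model_id, config=config, **kwargs)


# export=True routes through MyOptimizedModel._export; export=False goes to _from_pretrained.
# model = MyOptimizedModel.from_pretrained("some/checkpoint", export=True)
```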
