Skip to content

Commit 2c0476e

Browse files
authored
Enable ONNX export for transformers 4.45 (#2045)
* Enable ONNX export for transformers 4.45 * add comment * update setup
1 parent d3c56cd commit 2c0476e

File tree

2 files changed

+7
-11
lines changed

2 files changed

+7
-11
lines changed

optimum/exporters/onnx/convert.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626

2727
import numpy as np
2828
import onnx
29-
import transformers
3029
from transformers.modeling_utils import get_parameter_dtype
3130
from transformers.utils import is_tf_available, is_torch_available
3231

@@ -531,6 +530,11 @@ def export_pytorch(
531530
logger.info(f"Using framework PyTorch: {torch.__version__}")
532531
FORCE_ONNX_EXTERNAL_DATA = os.getenv("FORCE_ONNX_EXTERNAL_DATA", "0") == "1"
533532

533+
model_kwargs = model_kwargs or {}
534+
# num_logits_to_keep was added in transformers 4.45 and isn't added as inputs when exporting the model
535+
if check_if_transformers_greater("4.44.99") and "num_logits_to_keep" in signature(model.forward).parameters.keys():
536+
model_kwargs["num_logits_to_keep"] = 0
537+
534538
with torch.no_grad():
535539
model.config.return_dict = True
536540
model = model.eval()
@@ -1001,11 +1005,6 @@ def onnx_export_from_model(
10011005
>>> onnx_export_from_model(model, output="gpt2_onnx/")
10021006
```
10031007
"""
1004-
if check_if_transformers_greater("4.44.99"):
1005-
raise ImportError(
1006-
f"ONNX conversion disabled for now for transformers version greater than v4.45, found {transformers.__version__}"
1007-
)
1008-
10091008
TasksManager.standardize_model_attributes(model)
10101009

10111010
if hasattr(model.config, "export_model_type"):

setup.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,6 @@
5454
"datasets>=1.2.1",
5555
"evaluate",
5656
"protobuf>=3.20.1",
57-
"transformers<4.45.0",
5857
],
5958
"onnxruntime-gpu": [
6059
"onnx",
@@ -63,10 +62,9 @@
6362
"evaluate",
6463
"protobuf>=3.20.1",
6564
"accelerate", # ORTTrainer requires it.
66-
"transformers<4.45.0",
6765
],
68-
"exporters": ["onnx", "onnxruntime", "timm", "transformers<4.45.0"],
69-
"exporters-gpu": ["onnx", "onnxruntime-gpu", "timm", "transformers<4.45.0"],
66+
"exporters": ["onnx", "onnxruntime", "timm"],
67+
"exporters-gpu": ["onnx", "onnxruntime-gpu", "timm"],
7068
"exporters-tf": [
7169
"tensorflow>=2.4,<=2.12.1",
7270
"tf2onnx",
@@ -77,7 +75,6 @@
7775
"numpy<1.24.0",
7876
"datasets<=2.16",
7977
"transformers[sentencepiece]>=4.26,<4.38",
80-
"transformers<4.45.0",
8178
],
8279
"diffusers": ["diffusers"],
8380
"intel": "optimum-intel>=1.18.0",

0 commit comments

Comments
 (0)