Skip to content

Commit f1b708c

Browse files
authored
Fixes detection of CuPy installed with pre-built wheels (#1965)
The CuPy library ships both a source distribution (`cupy`) as well as versions containing pre-built wheels (`cupy-cuda11x`, `cupy-cuda12x`, `cupy-rocm-5-0`, `cupy-rocm-4-3`). Use of `_is_package_available` to detect CuPy only works for the source distribution of CuPy and fails when using the pre-built wheels versions. This is because the `_is_package_available` will always attempt to resolve version information (even if it's not required) and in doing so assumes that the _importable_ package name matches the _installed_ distribution name. While this is usually the case, it doesn't work for CuPy and several other libraries. ONNX Runtime for example might be installed as `onnxruntime` or `onnxruntime-gpu` and thus Optimum just uses `importlib.util.find_spec` to work around the same problem. This commit replicates the same solution for CuPy.
1 parent 26949f5 commit f1b708c

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

optimum/onnxruntime/utils.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414
"""Utility functions, classes and constants for ONNX Runtime."""
1515

16+
import importlib
1617
import os
1718
import re
1819
from enum import Enum
@@ -31,7 +32,6 @@
3132
import onnxruntime as ort
3233

3334
from ..exporters.onnx import OnnxConfig, OnnxConfigWithLoss
34-
from ..utils.import_utils import _is_package_available
3535

3636

3737
if TYPE_CHECKING:
@@ -91,9 +91,11 @@ def is_onnxruntime_training_available():
9191

9292
def is_cupy_available():
9393
"""
94-
Checks if onnxruntime-training is available.
94+
Checks if CuPy is available.
9595
"""
96-
return _is_package_available("cupy")
96+
# Don't use _is_package_available as it doesn't work with CuPy installed
97+
# with `cupy-cuda*` and `cupy-rocm-*` package name (prebuilt wheels).
98+
return importlib.util.find_spec("cupy") is not None
9799

98100

99101
class ORTConfigManager:

0 commit comments

Comments
 (0)