Skip to content

Commit 141c078

Browse files
ogriselMarcBresson
authored andcommitted
Drop support for the redundant and deprecated cupy.array_api in favor of array_api_compat. (scikit-learn#29639)
1 parent d4519f4 commit 141c078

File tree

9 files changed

+32
-141
lines changed

9 files changed

+32
-141
lines changed

.github/workflows/cuda-ci.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,4 +42,5 @@ jobs:
4242
run: |
4343
source "${HOME}/conda/etc/profile.d/conda.sh"
4444
conda activate sklearn
45+
python -c "import sklearn; sklearn.show_versions()"
4546
SCIPY_ARRAY_API=1 pytest -k 'array_api'

build_tools/github/create_gpu_environment.sh

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,6 @@ conda activate base
1515
CONDA_ENV_NAME=sklearn
1616
LOCK_FILE=build_tools/github/pylatest_conda_forge_cuda_array-api_linux-64_conda.lock
1717
create_conda_environment_from_lock_file $CONDA_ENV_NAME $LOCK_FILE
18+
19+
conda activate $CONDA_ENV_NAME
20+
conda list

doc/modules/array_api.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@ At this stage, this support is **considered experimental** and must be enabled
2121
explicitly as explained in the following.
2222

2323
.. note::
24-
Currently, only `cupy.array_api`, `array-api-strict`, `cupy`, and `PyTorch`
25-
are known to work with scikit-learn's estimators.
24+
Currently, only `array-api-strict`, `cupy`, and `PyTorch` are known to work
25+
with scikit-learn's estimators.
2626

2727
Example usage
2828
=============

doc/whats_new/v1.6.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,13 @@ See :ref:`array_api` for more details.
6262
compatible inputs when their base estimators do. :pr:`27096` by :user:`Tim
6363
Head <betatim>` and :user:`Olivier Grisel <ogrisel>`.
6464

65+
**Other**
66+
67+
- Support for the soon to be deprecated `cupy.array_api` module has been
68+
removed in favor of directly supporting the top level `cupy` module, possibly
69+
via the `array_api_compat.cupy` compatibility wrapper. :pr:`29639` by
70+
:user:`Olivier Grisel <ogrisel>`.
71+
6572
Metadata Routing
6673
----------------
6774

sklearn/metrics/pairwise.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,12 @@
2222
gen_even_slices,
2323
)
2424
from ..utils._array_api import (
25-
_clip,
2625
_fill_or_add_to_diagonal,
2726
_find_matching_floating_dtype,
2827
_is_numpy_namespace,
2928
_max_precision_float_dtype,
3029
_modify_in_place_if_numpy,
30+
device,
3131
get_namespace,
3232
get_namespace_and_device,
3333
)
@@ -1166,7 +1166,10 @@ def cosine_distances(X, Y=None):
11661166
S = cosine_similarity(X, Y)
11671167
S *= -1
11681168
S += 1
1169-
S = _clip(S, 0, 2, xp)
1169+
# TODO: remove the xp.asarray calls once the following is fixed:
1170+
# https://github.com/data-apis/array-api-compat/issues/177
1171+
device_ = device(S)
1172+
S = xp.clip(S, xp.asarray(0.0, device=device_), xp.asarray(2.0, device=device_))
11701173
if X is Y or Y is None:
11711174
# Ensure that distances between vectors and themselves are set to 0.0.
11721175
# This may not be the case due to floating point rounding errors.

sklearn/utils/_array_api.py

Lines changed: 1 addition & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@ def yield_namespaces(include_numpy_namespaces=True):
4343
# array_api_strict.Array instances always have a dummy "device" attribute.
4444
"array_api_strict",
4545
"cupy",
46-
"cupy.array_api",
4746
"torch",
4847
]:
4948
if not include_numpy_namespaces and array_namespace in _NUMPY_NAMESPACE_NAMES:
@@ -242,7 +241,7 @@ def _isdtype_single(dtype, kind, *, xp):
242241
elif kind == "real floating":
243242
return dtype in supported_float_dtypes(xp)
244243
elif kind == "complex floating":
245-
# Some name spaces do not have complex, such as cupy.array_api
244+
# Some name spaces might not have support for complex dtypes.
246245
complex_dtypes = set()
247246
if hasattr(xp, "complex64"):
248247
complex_dtypes.add(xp.complex64)
@@ -304,42 +303,6 @@ def ensure_common_namespace_device(reference, *arrays):
304303
return arrays
305304

306305

307-
class _ArrayAPIWrapper:
308-
"""sklearn specific Array API compatibility wrapper
309-
310-
This wrapper makes it possible for scikit-learn maintainers to
311-
deal with discrepancies between different implementations of the
312-
Python Array API standard and its evolution over time.
313-
314-
The Python Array API standard specification:
315-
https://data-apis.org/array-api/latest/
316-
317-
Documentation of the NumPy implementation:
318-
https://numpy.org/neps/nep-0047-array-api-standard.html
319-
"""
320-
321-
def __init__(self, array_namespace):
322-
self._namespace = array_namespace
323-
324-
def __getattr__(self, name):
325-
return getattr(self._namespace, name)
326-
327-
def __eq__(self, other):
328-
return self._namespace == other._namespace
329-
330-
def isdtype(self, dtype, kind):
331-
return isdtype(dtype, kind, xp=self._namespace)
332-
333-
def maximum(self, x1, x2):
334-
# TODO: Remove when `maximum` is made compatible in `array_api_compat`,
335-
# based on the `2023.12` specification.
336-
# https://github.com/data-apis/array-api-compat/issues/127
337-
x1_np = _convert_to_numpy(x1, xp=self._namespace)
338-
x2_np = _convert_to_numpy(x2, xp=self._namespace)
339-
x_max = numpy.maximum(x1_np, x2_np)
340-
return self._namespace.asarray(x_max, device=device(x1, x2))
341-
342-
343306
def _check_device_cpu(device): # noqa
344307
if device not in {"cpu", None}:
345308
raise ValueError(f"Unsupported device for NumPy: {device!r}")
@@ -597,11 +560,6 @@ def get_namespace(*arrays, remove_none=True, remove_types=(str,), xp=None):
597560

598561
namespace, is_array_api_compliant = array_api_compat.get_namespace(*arrays), True
599562

600-
# These namespaces need additional wrapping to smooth out small differences
601-
# between implementations
602-
if namespace.__name__ in {"cupy.array_api"}:
603-
namespace = _ArrayAPIWrapper(namespace)
604-
605563
if namespace.__name__ == "array_api_strict" and hasattr(
606564
namespace, "set_array_api_strict_flags"
607565
):
@@ -827,19 +785,6 @@ def _nanmax(X, axis=None, xp=None):
827785
return X
828786

829787

830-
def _clip(S, min_val, max_val, xp):
831-
# TODO: remove this method and change all usage once we move to array api 2023.12
832-
# https://data-apis.org/array-api/2023.12/API_specification/generated/array_api.clip.html#clip
833-
if _is_numpy_namespace(xp):
834-
return numpy.clip(S, min_val, max_val)
835-
else:
836-
min_arr = xp.asarray(min_val, dtype=S.dtype)
837-
max_arr = xp.asarray(max_val, dtype=S.dtype)
838-
S = xp.where(S < min_arr, min_arr, S)
839-
S = xp.where(S > max_arr, max_arr, S)
840-
return S
841-
842-
843788
def _asarray_with_order(
844789
array, dtype=None, order=None, copy=None, *, xp=None, device=None
845790
):
@@ -890,8 +835,6 @@ def _convert_to_numpy(array, xp):
890835

891836
if xp_name in {"array_api_compat.torch", "torch"}:
892837
return array.cpu().numpy()
893-
elif xp_name == "cupy.array_api":
894-
return array._array.get()
895838
elif xp_name in {"array_api_compat.cupy", "cupy"}: # pragma: nocover
896839
return array.get()
897840

sklearn/utils/_testing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1024,7 +1024,7 @@ def _array_api_for_tests(array_namespace, device):
10241024
"MPS is not available because the current PyTorch install was not "
10251025
"built with MPS enabled."
10261026
)
1027-
elif array_namespace in {"cupy", "cupy.array_api"}: # pragma: nocover
1027+
elif array_namespace == "cupy": # pragma: nocover
10281028
import cupy
10291029

10301030
if cupy.cuda.runtime.getDeviceCount() == 0:

sklearn/utils/tests/test_array_api.py

Lines changed: 10 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from sklearn._config import config_context
1010
from sklearn.base import BaseEstimator
1111
from sklearn.utils._array_api import (
12-
_ArrayAPIWrapper,
1312
_asarray_with_order,
1413
_atol_for_type,
1514
_average,
@@ -104,48 +103,6 @@ def mock_getenv(key):
104103
xp_out, is_array_api_compliant = get_namespace(X_xp)
105104

106105

107-
class _AdjustableNameAPITestWrapper(_ArrayAPIWrapper):
108-
"""API wrapper that has an adjustable name. Used for testing."""
109-
110-
def __init__(self, array_namespace, name):
111-
super().__init__(array_namespace=array_namespace)
112-
self.__name__ = name
113-
114-
115-
def test_array_api_wrapper_astype():
116-
"""Test _ArrayAPIWrapper for ArrayAPIs that is not NumPy."""
117-
array_api_strict = pytest.importorskip("array_api_strict")
118-
xp_ = _AdjustableNameAPITestWrapper(array_api_strict, "array_api_strict")
119-
xp = _ArrayAPIWrapper(xp_)
120-
121-
X = xp.asarray(([[1, 2, 3], [3, 4, 5]]), dtype=xp.float64)
122-
X_converted = xp.astype(X, xp.float32)
123-
assert X_converted.dtype == xp.float32
124-
125-
X_converted = xp.asarray(X, dtype=xp.float32)
126-
assert X_converted.dtype == xp.float32
127-
128-
129-
def test_array_api_wrapper_maximum():
130-
"""Test _ArrayAPIWrapper `maximum` for ArrayAPIs other than NumPy.
131-
132-
This is mainly used to test for `cupy.array_api` but since that is
133-
not available on our coverage-enabled PR CI, we resort to using
134-
`array-api-strict`.
135-
"""
136-
array_api_strict = pytest.importorskip("array_api_strict")
137-
xp_ = _AdjustableNameAPITestWrapper(array_api_strict, "array_api_strict")
138-
xp = _ArrayAPIWrapper(xp_)
139-
140-
x1 = xp.asarray(([[1, 2, 3], [3, 9, 5]]), dtype=xp.int64)
141-
x2 = xp.asarray(([[0, 1, 6], [8, 4, 5]]), dtype=xp.int64)
142-
result = xp.asarray([[1, 2, 6], [8, 9, 5]], dtype=xp.int64)
143-
144-
x_max = xp.maximum(x1, x2)
145-
assert x_max.dtype == x1.dtype
146-
assert xp.all(xp.equal(x_max, result))
147-
148-
149106
@pytest.mark.parametrize("array_api", ["numpy", "array_api_strict"])
150107
def test_asarray_with_order(array_api):
151108
"""Test _asarray_with_order passes along order for NumPy arrays."""
@@ -158,21 +115,6 @@ def test_asarray_with_order(array_api):
158115
assert X_new_np.flags["F_CONTIGUOUS"]
159116

160117

161-
def test_asarray_with_order_ignored():
162-
"""Test _asarray_with_order ignores order for Generic ArrayAPI."""
163-
xp = pytest.importorskip("array_api_strict")
164-
xp_ = _AdjustableNameAPITestWrapper(xp, "array_api_strict")
165-
166-
X = numpy.asarray([[1.2, 3.4, 5.1], [3.4, 5.5, 1.2]], order="C")
167-
X = xp_.asarray(X)
168-
169-
X_new = _asarray_with_order(X, order="F", xp=xp_)
170-
171-
X_new_np = numpy.asarray(X_new)
172-
assert X_new_np.flags["C_CONTIGUOUS"]
173-
assert not X_new_np.flags["F_CONTIGUOUS"]
174-
175-
176118
@pytest.mark.parametrize(
177119
"array_namespace, device_, dtype_name", yield_namespace_device_dtype_combinations()
178120
)
@@ -351,8 +293,8 @@ def __init__(self, device_name):
351293
assert array1.device == device(array1, array1, array2)
352294

353295

354-
# TODO: add cupy and cupy.array_api to the list of libraries once the
355-
# the following upstream issue has been fixed:
296+
# TODO: add cupy to the list of libraries once the the following upstream issue
297+
# has been fixed:
356298
# https://github.com/cupy/cupy/issues/8180
357299
@skip_if_array_api_compat_not_configured
358300
@pytest.mark.parametrize("library", ["numpy", "array_api_strict", "torch"])
@@ -419,7 +361,7 @@ def test_ravel(namespace, _device, _dtype):
419361

420362

421363
@skip_if_array_api_compat_not_configured
422-
@pytest.mark.parametrize("library", ["cupy", "torch", "cupy.array_api"])
364+
@pytest.mark.parametrize("library", ["cupy", "torch"])
423365
def test_convert_to_numpy_gpu(library): # pragma: nocover
424366
"""Check convert_to_numpy for GPU backed libraries."""
425367
xp = pytest.importorskip(library)
@@ -459,7 +401,7 @@ def fit(self, X, y=None):
459401
[
460402
("torch", lambda array: array.cpu().numpy()),
461403
("array_api_strict", lambda array: numpy.asarray(array)),
462-
("cupy.array_api", lambda array: array._array.get()),
404+
("cupy", lambda array: array.get()),
463405
],
464406
)
465407
def test_convert_estimator_to_ndarray(array_namespace, converter):
@@ -500,15 +442,9 @@ def test_reshape_behavior():
500442
xp.reshape(X, -1)
501443

502444

503-
@pytest.mark.parametrize("wrapper", [_ArrayAPIWrapper, _NumPyAPIWrapper])
504-
def test_get_namespace_array_api_isdtype(wrapper):
505-
"""Test isdtype implementation from _ArrayAPIWrapper and _NumPyAPIWrapper."""
506-
507-
if wrapper == _ArrayAPIWrapper:
508-
xp_ = pytest.importorskip("array_api_strict")
509-
xp = _ArrayAPIWrapper(xp_)
510-
else:
511-
xp = _NumPyAPIWrapper()
445+
def test_get_namespace_array_api_isdtype():
446+
"""Test isdtype implementation from _NumPyAPIWrapper."""
447+
xp = _NumPyAPIWrapper()
512448

513449
assert xp.isdtype(xp.float32, xp.float32)
514450
assert xp.isdtype(xp.float32, "real floating")
@@ -533,10 +469,9 @@ def test_get_namespace_array_api_isdtype(wrapper):
533469

534470
assert not xp.isdtype(xp.float32, "complex floating")
535471

536-
if wrapper == _NumPyAPIWrapper:
537-
assert not xp.isdtype(xp.int8, "complex floating")
538-
assert xp.isdtype(xp.complex64, "complex floating")
539-
assert xp.isdtype(xp.complex128, "complex floating")
472+
assert not xp.isdtype(xp.int8, "complex floating")
473+
assert xp.isdtype(xp.complex64, "complex floating")
474+
assert xp.isdtype(xp.complex128, "complex floating")
540475

541476
with pytest.raises(ValueError, match="Unrecognized data type"):
542477
assert xp.isdtype(xp.int16, "unknown")

sklearn/utils/tests/test_validation.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2011,10 +2011,9 @@ def test_pandas_array_returns_ndarray(input_values):
20112011

20122012

20132013
@skip_if_array_api_compat_not_configured
2014-
@pytest.mark.parametrize("array_namespace", ["array_api_strict", "cupy.array_api"])
2015-
def test_check_array_array_api_has_non_finite(array_namespace):
2014+
def test_check_array_array_api_has_non_finite():
20162015
"""Checks that Array API arrays checks non-finite correctly."""
2017-
xp = pytest.importorskip(array_namespace)
2016+
xp = pytest.importorskip("array_api_strict")
20182017

20192018
X_nan = xp.asarray([[xp.nan, 1, 0], [0, xp.nan, 3]], dtype=xp.float32)
20202019
with config_context(array_api_dispatch=True):

0 commit comments

Comments
 (0)