Skip to content

Commit bbc0e9c

Browse files
authored
Merge pull request #207 from crusaderky/cupy_generic
BUG: cp.generic is an alias for np.generic
2 parents edde2e0 + 8b53a7e commit bbc0e9c

File tree

2 files changed

+20
-1
lines changed

2 files changed

+20
-1
lines changed

array_api_compat/common/_helpers.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def is_cupy_array(x):
9191
import cupy as cp
9292

9393
# TODO: Should we reject ndarray subclasses?
94-
return isinstance(x, (cp.ndarray, cp.generic))
94+
return isinstance(x, cp.ndarray)
9595

9696
def is_torch_array(x):
9797
"""

tests/test_common.py

+19
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,25 @@ def test_is_xp_namespace(library, func):
5555
assert is_func(lib) == (func == is_namespace_functions[library])
5656

5757

58+
@pytest.mark.parametrize('library', all_libraries)
59+
def test_xp_is_array_generics(library):
60+
"""
61+
Test that scalar selection on a xp.ndarray always returns
62+
an object that matches with exactly one among the is_*_array
63+
function of the same library and is_numpy_array.
64+
"""
65+
lib = import_(library)
66+
x = lib.asarray([1, 2, 3])
67+
x0 = x[0]
68+
69+
matches = []
70+
for library2, func in is_array_functions.items():
71+
is_func = globals()[func]
72+
if is_func(x0):
73+
matches.append(library2)
74+
assert matches in ([library], ["numpy"])
75+
76+
5877
@pytest.mark.parametrize("library", all_libraries)
5978
def test_device(library):
6079
xp = import_(library, wrapper=True)

0 commit comments

Comments
 (0)