File tree 2 files changed +20
-1
lines changed
2 files changed +20
-1
lines changed Original file line number Diff line number Diff line change @@ -91,7 +91,7 @@ def is_cupy_array(x):
91
91
import cupy as cp
92
92
93
93
# TODO: Should we reject ndarray subclasses?
94
- return isinstance (x , ( cp .ndarray , cp . generic ) )
94
+ return isinstance (x , cp .ndarray )
95
95
96
96
def is_torch_array (x ):
97
97
"""
Original file line number Diff line number Diff line change @@ -55,6 +55,25 @@ def test_is_xp_namespace(library, func):
55
55
assert is_func (lib ) == (func == is_namespace_functions [library ])
56
56
57
57
58
+ @pytest .mark .parametrize ('library' , all_libraries )
59
+ def test_xp_is_array_generics (library ):
60
+ """
61
+ Test that scalar selection on a xp.ndarray always returns
62
+ an object that matches with exactly one among the is_*_array
63
+ function of the same library or with is_numpy_array.
64
+ """
65
+ lib = import_ (library )
66
+ x = lib .asarray ([1 , 2 , 3 ])
67
+ x0 = x [0 ]
68
+
69
+ matches = []
70
+ for library2 , func in is_array_functions .items ():
71
+ is_func = globals ()[func ]
72
+ if is_func (x0 ):
73
+ matches .append (library2 )
74
+ assert matches in ([library ], ["numpy" ])
75
+
76
+
58
77
@pytest .mark .parametrize ("library" , all_libraries )
59
78
def test_device (library ):
60
79
xp = import_ (library , wrapper = True )
You can’t perform that action at this time.
0 commit comments