10
10
import array
11
11
from numpy .testing import assert_allclose
12
12
13
- is_functions = {
13
+ is_array_functions = {
14
14
'numpy' : 'is_numpy_array' ,
15
15
'cupy' : 'is_cupy_array' ,
16
16
'torch' : 'is_torch_array' ,
19
19
'sparse' : 'is_pydata_sparse_array' ,
20
20
}
21
21
22
- @pytest .mark .parametrize ('library' , is_functions .keys ())
23
- @pytest .mark .parametrize ('func' , is_functions .values ())
22
+ is_namespace_functions = {
23
+ 'numpy' : 'is_numpy_namespace' ,
24
+ 'cupy' : 'is_cupy_namespace' ,
25
+ 'torch' : 'is_torch_namespace' ,
26
+ 'dask.array' : 'is_dask_namespace' ,
27
+ 'jax.numpy' : 'is_jax_namespace' ,
28
+ 'sparse' : 'is_pydata_sparse_namespace' ,
29
+ }
30
+
31
+
32
+ @pytest .mark .parametrize ('library' , is_array_functions .keys ())
33
+ @pytest .mark .parametrize ('func' , is_array_functions .values ())
24
34
def test_is_xp_array (library , func ):
25
35
lib = import_ (library )
26
36
is_func = globals ()[func ]
27
37
28
38
x = lib .asarray ([1 , 2 , 3 ])
29
39
30
- assert is_func (x ) == (func == is_functions [library ])
40
+ assert is_func (x ) == (func == is_array_functions [library ])
31
41
32
42
assert is_array_api_obj (x )
33
43
44
+
45
+ @pytest .mark .parametrize ('library' , is_namespace_functions .keys ())
46
+ @pytest .mark .parametrize ('func' , is_namespace_functions .values ())
47
+ def test_is_xp_namespace (library , func ):
48
+ lib = import_ (library )
49
+ is_func = globals ()[func ]
50
+
51
+ assert is_func (lib ) == (func == is_namespace_functions [library ])
52
+
53
+
34
54
@pytest .mark .parametrize ("library" , all_libraries )
35
55
def test_device (library ):
36
56
xp = import_ (library , wrapper = True )
@@ -64,8 +84,8 @@ def test_to_device_host(library):
64
84
assert_allclose (x , expected )
65
85
66
86
67
- @pytest .mark .parametrize ("target_library" , is_functions .keys ())
68
- @pytest .mark .parametrize ("source_library" , is_functions .keys ())
87
+ @pytest .mark .parametrize ("target_library" , is_array_functions .keys ())
88
+ @pytest .mark .parametrize ("source_library" , is_array_functions .keys ())
69
89
def test_asarray_cross_library (source_library , target_library , request ):
70
90
if source_library == "dask.array" and target_library == "torch" :
71
91
# Allow rest of test to execute instead of immediately xfailing
@@ -81,7 +101,7 @@ def test_asarray_cross_library(source_library, target_library, request):
81
101
pytest .skip (reason = "`sparse` does not allow implicit densification" )
82
102
src_lib = import_ (source_library , wrapper = True )
83
103
tgt_lib = import_ (target_library , wrapper = True )
84
- is_tgt_type = globals ()[is_functions [target_library ]]
104
+ is_tgt_type = globals ()[is_array_functions [target_library ]]
85
105
86
106
a = src_lib .asarray ([1 , 2 , 3 ])
87
107
b = tgt_lib .asarray (a )
@@ -96,7 +116,7 @@ def test_asarray_copy(library):
96
116
# should be able to delete this.
97
117
xp = import_ (library , wrapper = True )
98
118
asarray = xp .asarray
99
- is_lib_func = globals ()[is_functions [library ]]
119
+ is_lib_func = globals ()[is_array_functions [library ]]
100
120
all = xp .all if library != 'dask.array' else lambda x : xp .all (x ).compute ()
101
121
102
122
if library == 'numpy' and xp .__version__ [0 ] < '2' and not hasattr (xp , '_CopyMode' ) :
0 commit comments