1
- from array_api_compat import (is_numpy_array , is_cupy_array , is_torch_array , # noqa: F401
2
- is_dask_array , is_jax_array , is_pydata_sparse_array )
1
+ from array_api_compat import ( # noqa: F401
2
+ is_numpy_array , is_cupy_array , is_torch_array ,
3
+ is_dask_array , is_jax_array , is_pydata_sparse_array ,
4
+ is_numpy_namespace , is_cupy_namespace , is_torch_namespace ,
5
+ is_dask_namespace , is_jax_namespace , is_pydata_sparse_namespace ,
6
+ )
3
7
4
8
from array_api_compat import is_array_api_obj , device , to_device
5
9
10
14
import array
11
15
from numpy .testing import assert_allclose
12
16
13
- is_functions = {
17
+ is_array_functions = {
14
18
'numpy' : 'is_numpy_array' ,
15
19
'cupy' : 'is_cupy_array' ,
16
20
'torch' : 'is_torch_array' ,
19
23
'sparse' : 'is_pydata_sparse_array' ,
20
24
}
21
25
22
- @pytest .mark .parametrize ('library' , is_functions .keys ())
23
- @pytest .mark .parametrize ('func' , is_functions .values ())
26
+ is_namespace_functions = {
27
+ 'numpy' : 'is_numpy_namespace' ,
28
+ 'cupy' : 'is_cupy_namespace' ,
29
+ 'torch' : 'is_torch_namespace' ,
30
+ 'dask.array' : 'is_dask_namespace' ,
31
+ 'jax.numpy' : 'is_jax_namespace' ,
32
+ 'sparse' : 'is_pydata_sparse_namespace' ,
33
+ }
34
+
35
+
36
+ @pytest .mark .parametrize ('library' , is_array_functions .keys ())
37
+ @pytest .mark .parametrize ('func' , is_array_functions .values ())
24
38
def test_is_xp_array (library , func ):
25
39
lib = import_ (library )
26
40
is_func = globals ()[func ]
27
41
28
42
x = lib .asarray ([1 , 2 , 3 ])
29
43
30
- assert is_func (x ) == (func == is_functions [library ])
44
+ assert is_func (x ) == (func == is_array_functions [library ])
31
45
32
46
assert is_array_api_obj (x )
33
47
48
+
49
+ @pytest .mark .parametrize ('library' , is_namespace_functions .keys ())
50
+ @pytest .mark .parametrize ('func' , is_namespace_functions .values ())
51
+ def test_is_xp_namespace (library , func ):
52
+ lib = import_ (library )
53
+ is_func = globals ()[func ]
54
+
55
+ assert is_func (lib ) == (func == is_namespace_functions [library ])
56
+
57
+
34
58
@pytest .mark .parametrize ("library" , all_libraries )
35
59
def test_device (library ):
36
60
xp = import_ (library , wrapper = True )
@@ -64,8 +88,8 @@ def test_to_device_host(library):
64
88
assert_allclose (x , expected )
65
89
66
90
67
- @pytest .mark .parametrize ("target_library" , is_functions .keys ())
68
- @pytest .mark .parametrize ("source_library" , is_functions .keys ())
91
+ @pytest .mark .parametrize ("target_library" , is_array_functions .keys ())
92
+ @pytest .mark .parametrize ("source_library" , is_array_functions .keys ())
69
93
def test_asarray_cross_library (source_library , target_library , request ):
70
94
if source_library == "dask.array" and target_library == "torch" :
71
95
# Allow rest of test to execute instead of immediately xfailing
@@ -81,7 +105,7 @@ def test_asarray_cross_library(source_library, target_library, request):
81
105
pytest .skip (reason = "`sparse` does not allow implicit densification" )
82
106
src_lib = import_ (source_library , wrapper = True )
83
107
tgt_lib = import_ (target_library , wrapper = True )
84
- is_tgt_type = globals ()[is_functions [target_library ]]
108
+ is_tgt_type = globals ()[is_array_functions [target_library ]]
85
109
86
110
a = src_lib .asarray ([1 , 2 , 3 ])
87
111
b = tgt_lib .asarray (a )
@@ -96,7 +120,7 @@ def test_asarray_copy(library):
96
120
# should be able to delete this.
97
121
xp = import_ (library , wrapper = True )
98
122
asarray = xp .asarray
99
- is_lib_func = globals ()[is_functions [library ]]
123
+ is_lib_func = globals ()[is_array_functions [library ]]
100
124
all = xp .all if library != 'dask.array' else lambda x : xp .all (x ).compute ()
101
125
102
126
if library == 'numpy' and xp .__version__ [0 ] < '2' and not hasattr (xp , '_CopyMode' ) :
0 commit comments