Skip to content

Commit 6280e63

Browse files
committed
TST: test is_*_namespace fns
1 parent 733d17c commit 6280e63

File tree

7 files changed

+69
-17
lines changed

7 files changed

+69
-17
lines changed

array_api_compat/common/_helpers.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,7 @@ def is_numpy_namespace(xp) -> bool:
276276
is_pydata_sparse_namespace
277277
is_array_api_strict_namespace
278278
"""
279-
return xp.__name__ in {'numpy', _compat_module_name + '.numpy'}
279+
return xp.__name__ in {'numpy', _compat_module_name() + '.numpy'}
280280

281281
def is_cupy_namespace(xp) -> bool:
282282
"""
@@ -296,7 +296,7 @@ def is_cupy_namespace(xp) -> bool:
296296
is_pydata_sparse_namespace
297297
is_array_api_strict_namespace
298298
"""
299-
return xp.__name__ in {'cupy', _compat_module_name + '.cupy'}
299+
return xp.__name__ in {'cupy', _compat_module_name() + '.cupy'}
300300

301301
def is_torch_namespace(xp) -> bool:
302302
"""
@@ -316,7 +316,7 @@ def is_torch_namespace(xp) -> bool:
316316
is_pydata_sparse_namespace
317317
is_array_api_strict_namespace
318318
"""
319-
return xp.__name__ in {'torch', _compat_module_name + '.torch'}
319+
return xp.__name__ in {'torch', _compat_module_name() + '.torch'}
320320

321321

322322
def is_ndonnx_namespace(xp):
@@ -355,7 +355,7 @@ def is_dask_namespace(xp):
355355
is_pydata_sparse_namespace
356356
is_array_api_strict_namespace
357357
"""
358-
return xp.__name__ in {'dask.array', _compat_module_name + '.dask.array'}
358+
return xp.__name__ in {'dask.array', _compat_module_name() + '.dask.array'}
359359

360360
def is_jax_namespace(xp):
361361
"""

tests/test_common.py

+34-10
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
1-
from array_api_compat import (is_numpy_array, is_cupy_array, is_torch_array, # noqa: F401
2-
is_dask_array, is_jax_array, is_pydata_sparse_array)
1+
from array_api_compat import ( # noqa: F401
2+
is_numpy_array, is_cupy_array, is_torch_array,
3+
is_dask_array, is_jax_array, is_pydata_sparse_array,
4+
is_numpy_namespace, is_cupy_namespace, is_torch_namespace,
5+
is_dask_namespace, is_jax_namespace, is_pydata_sparse_namespace,
6+
)
37

48
from array_api_compat import is_array_api_obj, device, to_device
59

@@ -10,7 +14,7 @@
1014
import array
1115
from numpy.testing import assert_allclose
1216

13-
is_functions = {
17+
is_array_functions = {
1418
'numpy': 'is_numpy_array',
1519
'cupy': 'is_cupy_array',
1620
'torch': 'is_torch_array',
@@ -19,18 +23,38 @@
1923
'sparse': 'is_pydata_sparse_array',
2024
}
2125

22-
@pytest.mark.parametrize('library', is_functions.keys())
23-
@pytest.mark.parametrize('func', is_functions.values())
26+
is_namespace_functions = {
27+
'numpy': 'is_numpy_namespace',
28+
'cupy': 'is_cupy_namespace',
29+
'torch': 'is_torch_namespace',
30+
'dask.array': 'is_dask_namespace',
31+
'jax.numpy': 'is_jax_namespace',
32+
'sparse': 'is_pydata_sparse_namespace',
33+
}
34+
35+
36+
@pytest.mark.parametrize('library', is_array_functions.keys())
37+
@pytest.mark.parametrize('func', is_array_functions.values())
2438
def test_is_xp_array(library, func):
2539
lib = import_(library)
2640
is_func = globals()[func]
2741

2842
x = lib.asarray([1, 2, 3])
2943

30-
assert is_func(x) == (func == is_functions[library])
44+
assert is_func(x) == (func == is_array_functions[library])
3145

3246
assert is_array_api_obj(x)
3347

48+
49+
@pytest.mark.parametrize('library', is_namespace_functions.keys())
50+
@pytest.mark.parametrize('func', is_namespace_functions.values())
51+
def test_is_xp_namespace(library, func):
52+
lib = import_(library)
53+
is_func = globals()[func]
54+
55+
assert is_func(lib) == (func == is_namespace_functions[library])
56+
57+
3458
@pytest.mark.parametrize("library", all_libraries)
3559
def test_device(library):
3660
xp = import_(library, wrapper=True)
@@ -64,8 +88,8 @@ def test_to_device_host(library):
6488
assert_allclose(x, expected)
6589

6690

67-
@pytest.mark.parametrize("target_library", is_functions.keys())
68-
@pytest.mark.parametrize("source_library", is_functions.keys())
91+
@pytest.mark.parametrize("target_library", is_array_functions.keys())
92+
@pytest.mark.parametrize("source_library", is_array_functions.keys())
6993
def test_asarray_cross_library(source_library, target_library, request):
7094
if source_library == "dask.array" and target_library == "torch":
7195
# Allow rest of test to execute instead of immediately xfailing
@@ -81,7 +105,7 @@ def test_asarray_cross_library(source_library, target_library, request):
81105
pytest.skip(reason="`sparse` does not allow implicit densification")
82106
src_lib = import_(source_library, wrapper=True)
83107
tgt_lib = import_(target_library, wrapper=True)
84-
is_tgt_type = globals()[is_functions[target_library]]
108+
is_tgt_type = globals()[is_array_functions[target_library]]
85109

86110
a = src_lib.asarray([1, 2, 3])
87111
b = tgt_lib.asarray(a)
@@ -96,7 +120,7 @@ def test_asarray_copy(library):
96120
# should be able to delete this.
97121
xp = import_(library, wrapper=True)
98122
asarray = xp.asarray
99-
is_lib_func = globals()[is_functions[library]]
123+
is_lib_func = globals()[is_array_functions[library]]
100124
all = xp.all if library != 'dask.array' else lambda x: xp.all(x).compute()
101125

102126
if library == 'numpy' and xp.__version__[0] < '2' and not hasattr(xp, '_CopyMode') :

tests/test_vendoring.py

+1
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ def test_vendoring_torch():
2020

2121
uses_torch._test_torch()
2222

23+
2324
def test_vendoring_dask():
2425
from vendor_test import uses_dask
2526
uses_dask._test_dask()

vendor_test/uses_cupy.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
# Basic test that vendoring works
22

3-
from .vendored._compat import cupy as cp_compat
3+
from .vendored._compat import (
4+
cupy as cp_compat,
5+
is_cupy_array,
6+
is_cupy_namespace,
7+
)
48

59
import cupy as cp
610

@@ -16,3 +20,6 @@ def _test_cupy():
1620
assert isinstance(res, cp.ndarray)
1721

1822
cp.testing.assert_allclose(res, [1., 2., 9.])
23+
24+
assert is_cupy_array(res)
25+
assert is_cupy_namespace(cp) and is_cupy_namespace(cp_compat)

vendor_test/uses_dask.py

+4
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Basic test that vendoring works
22

33
from .vendored._compat.dask import array as dask_compat
4+
from .vendored._compat import is_dask_array, is_dask_namespace
45

56
import dask.array as da
67
import numpy as np
@@ -17,3 +18,6 @@ def _test_dask():
1718
assert isinstance(res, da.Array)
1819

1920
np.testing.assert_allclose(res, [1., 2., 9.])
21+
22+
assert is_dask_array(res)
23+
assert is_dask_namespace(da) and is_dask_namespace(dask_compat)

vendor_test/uses_numpy.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
# Basic test that vendoring works
22

3-
from .vendored._compat import numpy as np_compat
3+
from .vendored._compat import (
4+
is_numpy_array,
5+
is_numpy_namespace,
6+
numpy as np_compat,
7+
)
8+
49

510
import numpy as np
611

@@ -16,3 +21,6 @@ def _test_numpy():
1621
assert isinstance(res, np.ndarray)
1722

1823
np.testing.assert_allclose(res, [1., 2., 9.])
24+
25+
assert is_numpy_array(res)
26+
assert is_numpy_namespace(np) and is_numpy_namespace(np_compat)

vendor_test/uses_torch.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
# Basic test that vendoring works
22

3-
from .vendored._compat import torch as torch_compat
3+
from .vendored._compat import (
4+
is_torch_array,
5+
is_torch_namespace,
6+
torch as torch_compat,
7+
)
48

59
import torch
610

@@ -20,3 +24,7 @@ def _test_torch():
2024
assert isinstance(res, torch.Tensor)
2125

2226
torch.testing.assert_allclose(res, [[1., 2., 3.]])
27+
28+
assert is_torch_array(res)
29+
assert is_torch_namespace(torch) and is_torch_namespace(torch_compat)
30+

0 commit comments

Comments
 (0)