Skip to content

Commit 8e3f0b6

Browse files
committed
Merge branch 'main' into more-2023
2 parents 4c9dd0e + 6f9edc7 commit 8e3f0b6

File tree

13 files changed

+268
-22
lines changed

13 files changed

+268
-22
lines changed

.github/workflows/publish-package.yml

+2-2
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ jobs:
9494
if: >-
9595
(github.event_name == 'push' && startsWith(github.ref, 'refs/tags'))
9696
|| (github.event_name == 'workflow_dispatch' && github.event.inputs.publish == 'true')
97-
uses: pypa/gh-action-pypi-publish@v1.9.0
97+
uses: pypa/gh-action-pypi-publish@v1.10.1
9898
with:
9999
repository-url: https://test.pypi.org/legacy/
100100
print-hash: true
@@ -107,6 +107,6 @@ jobs:
107107

108108
- name: Publish distribution 📦 to PyPI
109109
if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags')
110-
uses: pypa/gh-action-pypi-publish@v1.9.0
110+
uses: pypa/gh-action-pypi-publish@v1.10.1
111111
with:
112112
print-hash: true

array_api_compat/common/_helpers.py

+168-1
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,6 @@ def is_jax_array(x):
202202

203203
return isinstance(x, jax.Array) or _is_jax_zero_gradient_array(x)
204204

205-
206205
def is_pydata_sparse_array(x) -> bool:
207206
"""
208207
Return True if `x` is an array from the `sparse` package.
@@ -255,6 +254,166 @@ def is_array_api_obj(x):
255254
or is_pydata_sparse_array(x) \
256255
or hasattr(x, '__array_namespace__')
257256

257+
def _compat_module_name():
258+
assert __name__.endswith('.common._helpers')
259+
return __name__.removesuffix('.common._helpers')
260+
261+
def is_numpy_namespace(xp) -> bool:
262+
"""
263+
Returns True if `xp` is a NumPy namespace.
264+
265+
This includes both NumPy itself and the version wrapped by array-api-compat.
266+
267+
See Also
268+
--------
269+
270+
array_namespace
271+
is_cupy_namespace
272+
is_torch_namespace
273+
is_ndonnx_namespace
274+
is_dask_namespace
275+
is_jax_namespace
276+
is_pydata_sparse_namespace
277+
is_array_api_strict_namespace
278+
"""
279+
return xp.__name__ in {'numpy', _compat_module_name() + '.numpy'}
280+
281+
def is_cupy_namespace(xp) -> bool:
282+
"""
283+
Returns True if `xp` is a CuPy namespace.
284+
285+
This includes both CuPy itself and the version wrapped by array-api-compat.
286+
287+
See Also
288+
--------
289+
290+
array_namespace
291+
is_numpy_namespace
292+
is_torch_namespace
293+
is_ndonnx_namespace
294+
is_dask_namespace
295+
is_jax_namespace
296+
is_pydata_sparse_namespace
297+
is_array_api_strict_namespace
298+
"""
299+
return xp.__name__ in {'cupy', _compat_module_name() + '.cupy'}
300+
301+
def is_torch_namespace(xp) -> bool:
302+
"""
303+
Returns True if `xp` is a PyTorch namespace.
304+
305+
This includes both PyTorch itself and the version wrapped by array-api-compat.
306+
307+
See Also
308+
--------
309+
310+
array_namespace
311+
is_numpy_namespace
312+
is_cupy_namespace
313+
is_ndonnx_namespace
314+
is_dask_namespace
315+
is_jax_namespace
316+
is_pydata_sparse_namespace
317+
is_array_api_strict_namespace
318+
"""
319+
return xp.__name__ in {'torch', _compat_module_name() + '.torch'}
320+
321+
322+
def is_ndonnx_namespace(xp):
323+
"""
324+
Returns True if `xp` is an NDONNX namespace.
325+
326+
See Also
327+
--------
328+
329+
array_namespace
330+
is_numpy_namespace
331+
is_cupy_namespace
332+
is_torch_namespace
333+
is_dask_namespace
334+
is_jax_namespace
335+
is_pydata_sparse_namespace
336+
is_array_api_strict_namespace
337+
"""
338+
return xp.__name__ == 'ndonnx'
339+
340+
def is_dask_namespace(xp):
341+
"""
342+
Returns True if `xp` is a Dask namespace.
343+
344+
This includes both ``dask.array`` itself and the version wrapped by array-api-compat.
345+
346+
See Also
347+
--------
348+
349+
array_namespace
350+
is_numpy_namespace
351+
is_cupy_namespace
352+
is_torch_namespace
353+
is_ndonnx_namespace
354+
is_jax_namespace
355+
is_pydata_sparse_namespace
356+
is_array_api_strict_namespace
357+
"""
358+
return xp.__name__ in {'dask.array', _compat_module_name() + '.dask.array'}
359+
360+
def is_jax_namespace(xp):
361+
"""
362+
Returns True if `xp` is a JAX namespace.
363+
364+
This includes ``jax.numpy`` and ``jax.experimental.array_api`` which existed in
365+
older versions of JAX.
366+
367+
See Also
368+
--------
369+
370+
array_namespace
371+
is_numpy_namespace
372+
is_cupy_namespace
373+
is_torch_namespace
374+
is_ndonnx_namespace
375+
is_dask_namespace
376+
is_pydata_sparse_namespace
377+
is_array_api_strict_namespace
378+
"""
379+
return xp.__name__ in {'jax.numpy', 'jax.experimental.array_api'}
380+
381+
def is_pydata_sparse_namespace(xp):
382+
"""
383+
Returns True if `xp` is a pydata/sparse namespace.
384+
385+
See Also
386+
--------
387+
388+
array_namespace
389+
is_numpy_namespace
390+
is_cupy_namespace
391+
is_torch_namespace
392+
is_ndonnx_namespace
393+
is_dask_namespace
394+
is_jax_namespace
395+
is_array_api_strict_namespace
396+
"""
397+
return xp.__name__ == 'sparse'
398+
399+
def is_array_api_strict_namespace(xp):
400+
"""
401+
Returns True if `xp` is an array-api-strict namespace.
402+
403+
See Also
404+
--------
405+
406+
array_namespace
407+
is_numpy_namespace
408+
is_cupy_namespace
409+
is_torch_namespace
410+
is_ndonnx_namespace
411+
is_dask_namespace
412+
is_jax_namespace
413+
is_pydata_sparse_namespace
414+
"""
415+
return xp.__name__ == 'array_api_strict'
416+
258417
def _check_api_version(api_version):
259418
if api_version == '2021.12':
260419
warnings.warn("The 2021.12 version of the array API specification was requested but the returned namespace is actually version 2022.12")
@@ -643,13 +802,21 @@ def size(x):
643802
"device",
644803
"get_namespace",
645804
"is_array_api_obj",
805+
"is_array_api_strict_namespace",
646806
"is_cupy_array",
807+
"is_cupy_namespace",
647808
"is_dask_array",
809+
"is_dask_namespace",
648810
"is_jax_array",
811+
"is_jax_namespace",
649812
"is_numpy_array",
813+
"is_numpy_namespace",
650814
"is_torch_array",
815+
"is_torch_namespace",
651816
"is_ndonnx_array",
817+
"is_ndonnx_namespace",
652818
"is_pydata_sparse_array",
819+
"is_pydata_sparse_namespace",
653820
"size",
654821
"to_device",
655822
]

array_api_compat/cupy/_aliases.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,13 @@ def asarray(
108108

109109
return cp.array(obj, dtype=dtype, **kwargs)
110110

111+
def sign(x: ndarray, /) -> ndarray:
112+
# CuPy sign() does not propagate nans. See
113+
# https://github.com/data-apis/array-api-compat/issues/136
114+
out = cp.sign(x)
115+
out[cp.isnan(x)] = cp.nan
116+
return out
117+
111118
# These functions are completely new here. If the library already has them
112119
# (i.e., numpy 2.0), use the library version instead of our wrapper.
113120
if hasattr(cp, 'vecdot'):
@@ -129,6 +136,6 @@ def asarray(
129136
'acos', 'acosh', 'asin', 'asinh', 'atan',
130137
'atan2', 'atanh', 'bitwise_left_shift',
131138
'bitwise_invert', 'bitwise_right_shift',
132-
'concat', 'pow']
139+
'concat', 'pow', 'sign']
133140

134141
_all_ignore = ['cp', 'get_xp']

array_api_compat/torch/_aliases.py

+16-1
Original file line numberDiff line numberDiff line change
@@ -709,6 +709,21 @@ def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) -
709709
axis = 0
710710
return torch.index_select(x, axis, indices, **kwargs)
711711

712+
def sign(x: array, /) -> array:
713+
# torch sign() does not support complex numbers and does not propagate
714+
# nans. See https://github.com/data-apis/array-api-compat/issues/136
715+
if x.dtype.is_complex:
716+
out = x/torch.abs(x)
717+
# sign(0) = 0 but the above formula would give nan
718+
out[x == 0+0j] = 0+0j
719+
return out
720+
else:
721+
out = torch.sign(x)
722+
if x.dtype.is_floating_point:
723+
out[torch.isnan(x)] = torch.nan
724+
return out
725+
726+
712727
__all__ = ['result_type', 'can_cast', 'permute_dims', 'bitwise_invert',
713728
'newaxis', 'conj', 'add', 'atan2', 'bitwise_and', 'bitwise_left_shift',
714729
'bitwise_or', 'bitwise_right_shift', 'bitwise_xor', 'copysign',
@@ -722,6 +737,6 @@ def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) -
722737
'broadcast_arrays', 'UniqueAllResult', 'UniqueCountsResult',
723738
'UniqueInverseResult', 'unique_all', 'unique_counts',
724739
'unique_inverse', 'unique_values', 'matmul', 'matrix_transpose',
725-
'vecdot', 'tensordot', 'isdtype', 'take']
740+
'vecdot', 'tensordot', 'isdtype', 'take', 'sign']
726741

727742
_all_ignore = ['torch', 'get_xp']

cupy-xfails.txt

-1
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,6 @@ array_api_tests/test_special_cases.py::test_unary[expm1(x_i is -0) -> -0]
160160
array_api_tests/test_special_cases.py::test_unary[floor(x_i is -0) -> -0]
161161
array_api_tests/test_special_cases.py::test_unary[log1p(x_i is -0) -> -0]
162162
array_api_tests/test_special_cases.py::test_unary[round(x_i is -0) -> -0]
163-
array_api_tests/test_special_cases.py::test_unary[sign(x_i is NaN) -> NaN]
164163
array_api_tests/test_special_cases.py::test_unary[sin(x_i is -0) -> -0]
165164
array_api_tests/test_special_cases.py::test_unary[sinh(x_i is -0) -> -0]
166165
array_api_tests/test_special_cases.py::test_unary[sqrt(x_i is -0) -> -0]

docs/helper-functions.rst

+9-1
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ instead, which would be wrapped.
4040
Inspection Helpers
4141
------------------
4242

43-
These convenience functions can be used to test if an array comes from a
43+
These convenience functions can be used to test if an array or namespace comes from a
4444
specific library without importing that library if it hasn't been imported
4545
yet.
4646

@@ -51,3 +51,11 @@ yet.
5151
.. autofunction:: is_jax_array
5252
.. autofunction:: is_pydata_sparse_array
5353
.. autofunction:: is_ndonnx_array
54+
.. autofunction:: is_numpy_namespace
55+
.. autofunction:: is_cupy_namespace
56+
.. autofunction:: is_torch_namespace
57+
.. autofunction:: is_dask_namespace
58+
.. autofunction:: is_jax_namespace
59+
.. autofunction:: is_pydata_sparse_namespace
60+
.. autofunction:: is_ndonnx_namespace
61+
.. autofunction:: is_array_api_strict_namespace

tests/test_common.py

+34-10
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
1-
from array_api_compat import (is_numpy_array, is_cupy_array, is_torch_array, # noqa: F401
2-
is_dask_array, is_jax_array, is_pydata_sparse_array)
1+
from array_api_compat import ( # noqa: F401
2+
is_numpy_array, is_cupy_array, is_torch_array,
3+
is_dask_array, is_jax_array, is_pydata_sparse_array,
4+
is_numpy_namespace, is_cupy_namespace, is_torch_namespace,
5+
is_dask_namespace, is_jax_namespace, is_pydata_sparse_namespace,
6+
)
37

48
from array_api_compat import is_array_api_obj, device, to_device
59

@@ -10,7 +14,7 @@
1014
import array
1115
from numpy.testing import assert_allclose
1216

13-
is_functions = {
17+
is_array_functions = {
1418
'numpy': 'is_numpy_array',
1519
'cupy': 'is_cupy_array',
1620
'torch': 'is_torch_array',
@@ -19,18 +23,38 @@
1923
'sparse': 'is_pydata_sparse_array',
2024
}
2125

22-
@pytest.mark.parametrize('library', is_functions.keys())
23-
@pytest.mark.parametrize('func', is_functions.values())
26+
is_namespace_functions = {
27+
'numpy': 'is_numpy_namespace',
28+
'cupy': 'is_cupy_namespace',
29+
'torch': 'is_torch_namespace',
30+
'dask.array': 'is_dask_namespace',
31+
'jax.numpy': 'is_jax_namespace',
32+
'sparse': 'is_pydata_sparse_namespace',
33+
}
34+
35+
36+
@pytest.mark.parametrize('library', is_array_functions.keys())
37+
@pytest.mark.parametrize('func', is_array_functions.values())
2438
def test_is_xp_array(library, func):
2539
lib = import_(library)
2640
is_func = globals()[func]
2741

2842
x = lib.asarray([1, 2, 3])
2943

30-
assert is_func(x) == (func == is_functions[library])
44+
assert is_func(x) == (func == is_array_functions[library])
3145

3246
assert is_array_api_obj(x)
3347

48+
49+
@pytest.mark.parametrize('library', is_namespace_functions.keys())
50+
@pytest.mark.parametrize('func', is_namespace_functions.values())
51+
def test_is_xp_namespace(library, func):
52+
lib = import_(library)
53+
is_func = globals()[func]
54+
55+
assert is_func(lib) == (func == is_namespace_functions[library])
56+
57+
3458
@pytest.mark.parametrize("library", all_libraries)
3559
def test_device(library):
3660
xp = import_(library, wrapper=True)
@@ -64,8 +88,8 @@ def test_to_device_host(library):
6488
assert_allclose(x, expected)
6589

6690

67-
@pytest.mark.parametrize("target_library", is_functions.keys())
68-
@pytest.mark.parametrize("source_library", is_functions.keys())
91+
@pytest.mark.parametrize("target_library", is_array_functions.keys())
92+
@pytest.mark.parametrize("source_library", is_array_functions.keys())
6993
def test_asarray_cross_library(source_library, target_library, request):
7094
if source_library == "dask.array" and target_library == "torch":
7195
# Allow rest of test to execute instead of immediately xfailing
@@ -81,7 +105,7 @@ def test_asarray_cross_library(source_library, target_library, request):
81105
pytest.skip(reason="`sparse` does not allow implicit densification")
82106
src_lib = import_(source_library, wrapper=True)
83107
tgt_lib = import_(target_library, wrapper=True)
84-
is_tgt_type = globals()[is_functions[target_library]]
108+
is_tgt_type = globals()[is_array_functions[target_library]]
85109

86110
a = src_lib.asarray([1, 2, 3])
87111
b = tgt_lib.asarray(a)
@@ -96,7 +120,7 @@ def test_asarray_copy(library):
96120
# should be able to delete this.
97121
xp = import_(library, wrapper=True)
98122
asarray = xp.asarray
99-
is_lib_func = globals()[is_functions[library]]
123+
is_lib_func = globals()[is_array_functions[library]]
100124
all = xp.all if library != 'dask.array' else lambda x: xp.all(x).compute()
101125

102126
if library == 'numpy' and xp.__version__[0] < '2' and not hasattr(xp, '_CopyMode') :

tests/test_vendoring.py

+1
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ def test_vendoring_torch():
2020

2121
uses_torch._test_torch()
2222

23+
2324
def test_vendoring_dask():
2425
from vendor_test import uses_dask
2526
uses_dask._test_dask()

0 commit comments

Comments
 (0)