Skip to content

Commit 47ddd17

Browse files
committed
Test for read-only arrays (data-apis#205)
1 parent 0d99436 commit 47ddd17

File tree

4 files changed

+37
-4
lines changed

4 files changed

+37
-4
lines changed

array_api_compat/__init__.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
"""
22
NumPy Array API compatibility library
33
4-
This is a small wrapper around NumPy and CuPy that is compatible with the
5-
Array API standard https://data-apis.org/array-api/latest/. See also NEP 47
6-
https://numpy.org/neps/nep-0047-array-api-standard.html.
4+
This is a small wrapper around NumPy, CuPy, JAX, sparse and others that is
5+
compatible with the Array API standard https://data-apis.org/array-api/latest/.
6+
See also NEP 47 https://numpy.org/neps/nep-0047-array-api-standard.html.
77
88
Unlike array_api_strict, this is not a strict minimal implementation of the
99
Array API, but rather just an extension of the main NumPy namespace with

array_api_compat/common/_helpers.py

+14
Original file line numberDiff line numberDiff line change
@@ -787,6 +787,7 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
787787
return x
788788
return x.to_device(device, stream=stream)
789789

790+
790791
def size(x):
791792
"""
792793
Return the total number of elements of x.
@@ -801,6 +802,18 @@ def size(x):
801802
return None
802803
return math.prod(x.shape)
803804

805+
806+
def is_writeable_array(x) -> bool:
807+
"""
808+
Return False if ``x.__setitem__`` is expected to raise; True otherwise
809+
"""
810+
if is_numpy_array(x):
811+
return x.flags.writeable
812+
if is_jax_array(x) or is_pydata_sparse_array(x):
813+
return False
814+
return True
815+
816+
804817
__all__ = [
805818
"array_namespace",
806819
"device",
@@ -821,6 +834,7 @@ def size(x):
821834
"is_ndonnx_namespace",
822835
"is_pydata_sparse_array",
823836
"is_pydata_sparse_namespace",
837+
"is_writeable_array",
824838
"size",
825839
"to_device",
826840
]

docs/helper-functions.rst

+1
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ yet.
5151
.. autofunction:: is_jax_array
5252
.. autofunction:: is_pydata_sparse_array
5353
.. autofunction:: is_ndonnx_array
54+
.. autofunction:: is_writeable_array
5455
.. autofunction:: is_numpy_namespace
5556
.. autofunction:: is_cupy_namespace
5657
.. autofunction:: is_torch_namespace

tests/test_common.py

+19-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
is_dask_namespace, is_jax_namespace, is_pydata_sparse_namespace,
66
)
77

8-
from array_api_compat import is_array_api_obj, device, to_device
8+
from array_api_compat import device, is_array_api_obj, is_writeable_array, to_device
99

1010
from ._helpers import import_, wrapped_libraries, all_libraries
1111

@@ -55,6 +55,24 @@ def test_is_xp_namespace(library, func):
5555
assert is_func(lib) == (func == is_namespace_functions[library])
5656

5757

58+
@pytest.mark.parametrize("library", all_libraries)
59+
def test_is_writeable_array(library):
60+
lib = import_(library)
61+
x = lib.asarray([1, 2, 3])
62+
if is_writeable_array(x):
63+
x[1] = 4
64+
else:
65+
with pytest.raises((TypeError, ValueError)):
66+
x[1] = 4
67+
68+
69+
def test_is_writeable_array_numpy():
70+
x = np.asarray([1, 2, 3])
71+
assert is_writeable_array(x)
72+
x.flags.writeable = False
73+
assert not is_writeable_array(x)
74+
75+
5876
@pytest.mark.parametrize("library", all_libraries)
5977
def test_device(library):
6078
xp = import_(library, wrapper=True)

0 commit comments

Comments
 (0)