Skip to content

Commit b9e64e5

Browse files
authored
Merge pull request #69 from asmeurer/rm-__array__
Remove __array__
2 parents ef71d85 + 5485345 commit b9e64e5

File tree

5 files changed

+66
-34
lines changed

5 files changed

+66
-34
lines changed

array_api_strict/_array_object.py

+29-22
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,8 @@ def __hash__(self):
6666

6767
_default = object()
6868

69+
_allow_array = False
70+
6971
class Array:
7072
"""
7173
n-d array object for the array API namespace.
@@ -145,30 +147,35 @@ def __repr__(self: Array, /) -> str:
145147
mid = np.array2string(self._array, separator=', ', prefix=prefix, suffix=suffix)
146148
return prefix + mid + suffix
147149

148-
# This function is not required by the spec, but we implement it here for
149-
# convenience so that np.asarray(array_api_strict.Array) will work.
150+
# Disallow __array__, meaning calling `np.func()` on an array_api_strict
151+
# array will give an error. If we don't explicitly disallow it, NumPy
152+
# defaults to creating an object dtype array, which would lead to
153+
# confusing error messages at best and surprising bugs at worst.
154+
#
155+
# The alternative of course is to just support __array__, which is what we
156+
# used to do. But this isn't actually supported by the standard, so it can
157+
# lead to code assuming np.asarray(other_array) would always work in the
158+
# standard.
150159
def __array__(self, dtype: None | np.dtype[Any] = None, copy: None | bool = None) -> npt.NDArray[Any]:
151-
"""
152-
Warning: this method is NOT part of the array API spec. Implementers
153-
of other libraries need not include it, and users should not assume it
154-
will be present in other implementations.
155-
156-
"""
157-
if self._device != CPU_DEVICE:
158-
raise RuntimeError(f"Can not convert array on the '{self._device}' device to a Numpy array.")
159-
# copy keyword is new in 2.0.0; for older versions don't use it
160-
# retry without that keyword.
161-
if np.__version__[0] < '2':
162-
return np.asarray(self._array, dtype=dtype)
163-
elif np.__version__.startswith('2.0.0-dev0'):
164-
# Handle dev version for which we can't know based on version
165-
# number whether or not the copy keyword is supported.
166-
try:
167-
return np.asarray(self._array, dtype=dtype, copy=copy)
168-
except TypeError:
160+
# We have to allow this to be internally enabled as there's no other
161+
# easy way to parse a list of Array objects in asarray().
162+
if _allow_array:
163+
if self._device != CPU_DEVICE:
164+
raise RuntimeError(f"Can not convert array on the '{self._device}' device to a Numpy array.")
165+
# copy keyword is new in 2.0.0; for older versions don't use it
166+
# retry without that keyword.
167+
if np.__version__[0] < '2':
169168
return np.asarray(self._array, dtype=dtype)
170-
else:
171-
return np.asarray(self._array, dtype=dtype, copy=copy)
169+
elif np.__version__.startswith('2.0.0-dev0'):
170+
# Handle dev version for which we can't know based on version
171+
# number whether or not the copy keyword is supported.
172+
try:
173+
return np.asarray(self._array, dtype=dtype, copy=copy)
174+
except TypeError:
175+
return np.asarray(self._array, dtype=dtype)
176+
else:
177+
return np.asarray(self._array, dtype=dtype, copy=copy)
178+
raise ValueError("Conversion from an array_api_strict array to a NumPy ndarray is not supported")
172179

173180
# These are various helper functions to make the array behavior match the
174181
# spec in places where it either deviates from or is more strict than

array_api_strict/_creation_functions.py

+16-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
3+
from contextlib import contextmanager
44
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
55

66
if TYPE_CHECKING:
@@ -16,6 +16,19 @@
1616

1717
import numpy as np
1818

19+
@contextmanager
20+
def allow_array():
21+
"""
22+
Temporarily enable Array.__array__. This is needed for np.array to parse
23+
list of lists of Array objects.
24+
"""
25+
from . import _array_object
26+
original_value = _array_object._allow_array
27+
try:
28+
_array_object._allow_array = True
29+
yield
30+
finally:
31+
_array_object._allow_array = original_value
1932

2033
def _check_valid_dtype(dtype):
2134
# Note: Only spelling dtypes as the dtype objects is supported.
@@ -99,7 +112,8 @@ def asarray(
99112
# Give a better error message in this case. NumPy would convert this
100113
# to an object array. TODO: This won't handle large integers in lists.
101114
raise OverflowError("Integer out of bounds for array dtypes")
102-
res = np.array(obj, dtype=_np_dtype, copy=copy)
115+
with allow_array():
116+
res = np.array(obj, dtype=_np_dtype, copy=copy)
103117
return Array._new(res, device=device)
104118

105119

array_api_strict/_linalg.py

+2
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,8 @@ def pinv(x: Array, /, *, rtol: Optional[Union[float, Array]] = None) -> Array:
267267
# default tolerance by max(M, N).
268268
if rtol is None:
269269
rtol = max(x.shape[-2:]) * finfo(x.dtype).eps
270+
if isinstance(rtol, Array):
271+
rtol = rtol._array
270272
return Array._new(np.linalg.pinv(x._array, rcond=rtol), device=x.device)
271273

272274
@requires_extension('linalg')

array_api_strict/tests/test_array_object.py

+5-9
Original file line numberDiff line numberDiff line change
@@ -350,23 +350,19 @@ def test_array_properties():
350350
assert isinstance(b.mT, Array)
351351
assert b.mT.shape == (3, 2)
352352

353-
def test___array__():
354-
a = ones((2, 3), dtype=int16)
355-
assert np.asarray(a) is a._array
356-
b = np.asarray(a, dtype=np.float64)
357-
assert np.all(np.equal(b, np.ones((2, 3), dtype=np.float64)))
358-
assert b.dtype == np.float64
359353

360354
def test_array_conversion():
361355
# Check that arrays on the CPU device can be converted to NumPy
362-
# but arrays on other devices can't
356+
# but arrays on other devices can't. Note this is testing the logic in
357+
# __array__, which is only used in asarray when converting lists of
358+
# arrays.
363359
a = ones((2, 3))
364-
np.asarray(a)
360+
asarray([a])
365361

366362
for device in ("device1", "device2"):
367363
a = ones((2, 3), device=array_api_strict.Device(device))
368364
with pytest.raises(RuntimeError, match="Can not convert array"):
369-
np.asarray(a)
365+
asarray([a])
370366

371367
def test_allow_newaxis():
372368
a = ones(5)

array_api_strict/tests/test_creation_functions.py

+14-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
zeros,
2323
zeros_like,
2424
)
25-
from .._dtypes import float32, float64
25+
from .._dtypes import int16, float32, float64
2626
from .._array_object import Array, CPU_DEVICE, Device
2727
from .._flags import set_array_api_strict_flags
2828

@@ -97,6 +97,19 @@ def test_asarray_copy():
9797
a[0] = 0
9898
assert all(b[0] == 0)
9999

100+
def test_asarray_list_of_lists():
101+
a = asarray(1, dtype=int16)
102+
b = asarray([1], dtype=int16)
103+
res = asarray([a, a])
104+
assert res.shape == (2,)
105+
assert res.dtype == int16
106+
assert all(res == asarray([1, 1]))
107+
108+
res = asarray([b, b])
109+
assert res.shape == (2, 1)
110+
assert res.dtype == int16
111+
assert all(res == asarray([[1], [1]]))
112+
100113

101114
def test_asarray_device_inference():
102115
assert asarray([1, 2, 3]).device == CPU_DEVICE

0 commit comments

Comments
 (0)