Skip to content

Commit 0c9bb5e

Browse files
committed
Re-enable __array__
Removing it caused issues for SciPy (data-apis#67). I have left the flag in to make it easy to remove it in the future. I also considered raising a warning in __array__, but this is also difficult to handle data-apis#91
1 parent 4e5ff09 commit 0c9bb5e

File tree

2 files changed

+30
-8
lines changed

2 files changed

+30
-8
lines changed

array_api_strict/_array_object.py

+13-8
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,9 @@ def __hash__(self):
6666

6767
_default = object()
6868

69-
_allow_array = False
69+
# See https://github.com/data-apis/array-api-strict/issues/67 and the comment
70+
# on __array__ below.
71+
_allow_array = True
7072

7173
class Array:
7274
"""
@@ -147,15 +149,18 @@ def __repr__(self: Array, /) -> str:
147149
mid = np.array2string(self._array, separator=', ', prefix=prefix, suffix=suffix)
148150
return prefix + mid + suffix
149151

150-
# Disallow __array__, meaning calling `np.func()` on an array_api_strict
151-
# array will give an error. If we don't explicitly disallow it, NumPy
152-
# defaults to creating an object dtype array, which would lead to
153-
# confusing error messages at best and surprising bugs at worst.
154-
#
155-
# The alternative of course is to just support __array__, which is what we
156-
# used to do. But this isn't actually supported by the standard, so it can
152+
# In the future, _allow_array will be set to False, which will disallow
153+
# __array__. This means calling `np.func()` on an array_api_strict array
154+
# will give an error. If we don't explicitly disallow it, NumPy defaults
155+
# to creating an object dtype array, which would lead to confusing error
156+
# messages at best and surprising bugs at worst. The reason for doing this
157+
# is that __array__ is not actually supported by the standard, so it can
157158
# lead to code assuming np.asarray(other_array) would always work in the
158159
# standard.
160+
#
161+
# This was implemented historically for compatibility, and removing it has
162+
# caused issues for some libraries (see
163+
# https://github.com/data-apis/array-api-strict/issues/67).
159164
def __array__(self, dtype: None | np.dtype[Any] = None, copy: None | bool = None) -> npt.NDArray[Any]:
160165
# We have to allow this to be internally enabled as there's no other
161166
# easy way to parse a list of Array objects in asarray().

array_api_strict/tests/test_array_object.py

+17
Original file line numberDiff line numberDiff line change
@@ -364,6 +364,23 @@ def test_array_conversion():
364364
with pytest.raises(RuntimeError, match="Can not convert array"):
365365
asarray([a])
366366

367+
def test__array__():
368+
# __array__ should work for now
369+
a = ones((2, 3))
370+
np.array(a)
371+
372+
# Test the _allow_array private global flag for disabling it in the
373+
# future.
374+
from .. import _array_object
375+
original_value = _array_object._allow_array
376+
try:
377+
_array_object._allow_array = False
378+
a = ones((2, 3))
379+
with pytest.raises(ValueError, match="Conversion from an array_api_strict array to a NumPy ndarray is not supported"):
380+
np.array(a)
381+
finally:
382+
_array_object._allow_array = original_value
383+
367384
def test_allow_newaxis():
368385
a = ones(5)
369386
indexed_a = a[None, :]

0 commit comments

Comments
 (0)