diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index c57d6ed..0de6b8a 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -66,7 +66,9 @@ def __hash__(self): _default = object() -_allow_array = False +# See https://github.com/data-apis/array-api-strict/issues/67 and the comment +# on __array__ below. +_allow_array = True class Array: """ @@ -147,15 +149,18 @@ def __repr__(self: Array, /) -> str: mid = np.array2string(self._array, separator=', ', prefix=prefix, suffix=suffix) return prefix + mid + suffix - # Disallow __array__, meaning calling `np.func()` on an array_api_strict - # array will give an error. If we don't explicitly disallow it, NumPy - # defaults to creating an object dtype array, which would lead to - # confusing error messages at best and surprising bugs at worst. - # - # The alternative of course is to just support __array__, which is what we - # used to do. But this isn't actually supported by the standard, so it can + # In the future, _allow_array will be set to False, which will disallow + # __array__. This means calling `np.func()` on an array_api_strict array + # will give an error. If we don't explicitly disallow it, NumPy defaults + # to creating an object dtype array, which would lead to confusing error + # messages at best and surprising bugs at worst. The reason for doing this + # is that __array__ is not actually supported by the standard, so it can # lead to code assuming np.asarray(other_array) would always work in the # standard. + # + # This was implemented historically for compatibility, and removing it has + # caused issues for some libraries (see + # https://github.com/data-apis/array-api-strict/issues/67). def __array__(self, dtype: None | np.dtype[Any] = None, copy: None | bool = None) -> npt.NDArray[Any]: # We have to allow this to be internally enabled as there's no other # easy way to parse a list of Array objects in asarray(). diff --git a/array_api_strict/tests/test_array_object.py b/array_api_strict/tests/test_array_object.py index 4f843ba..29b7d17 100644 --- a/array_api_strict/tests/test_array_object.py +++ b/array_api_strict/tests/test_array_object.py @@ -364,6 +364,23 @@ def test_array_conversion(): with pytest.raises(RuntimeError, match="Can not convert array"): asarray([a]) +def test__array__(): + # __array__ should work for now + a = ones((2, 3)) + np.array(a) + + # Test the _allow_array private global flag for disabling it in the + # future. + from .. import _array_object + original_value = _array_object._allow_array + try: + _array_object._allow_array = False + a = ones((2, 3)) + with pytest.raises(ValueError, match="Conversion from an array_api_strict array to a NumPy ndarray is not supported"): + np.array(a) + finally: + _array_object._allow_array = original_value + def test_allow_newaxis(): a = ones(5) indexed_a = a[None, :]