From abcef28054dea589205a120423996a280130cadc Mon Sep 17 00:00:00 2001 From: Ralf Gommers Date: Sat, 2 Mar 2024 08:10:05 +0100 Subject: [PATCH] Support `copy` keyword in `__array__` method Closes gh-11 --- array_api_strict/_array_object.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index 76eb1f1..b659722 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -122,14 +122,26 @@ def __repr__(self: Array, /) -> str: # This function is not required by the spec, but we implement it here for # convenience so that np.asarray(array_api_strict.Array) will work. - def __array__(self, dtype: None | np.dtype[Any] = None) -> npt.NDArray[Any]: + def __array__(self, dtype: None | np.dtype[Any] = None, copy: None | bool = None) -> npt.NDArray[Any]: """ Warning: this method is NOT part of the array API spec. Implementers of other libraries need not include it, and users should not assume it will be present in other implementations. """ - return np.asarray(self._array, dtype=dtype) + # copy keyword is new in 2.0.0; for older versions don't use it + # retry without that keyword. + if np.__version__[0] < '2': + return np.asarray(self._array, dtype=dtype) + elif np.__version__.startswith('2.0.0-dev0'): + # Handle dev version for which we can't know based on version + # number whether or not the copy keyword is supported. + try: + return np.asarray(self._array, dtype=dtype, copy=copy) + except TypeError: + return np.asarray(self._array, dtype=dtype) + else: + return np.asarray(self._array, dtype=dtype, copy=copy) # These are various helper functions to make the array behavior match the # spec in places where it either deviates from or is more strict than