Skip to content

Commit abcef28

Browse files
committed
Support copy keyword in __array__ method
Closes gh-11
1 parent 6e0dfa0 commit abcef28

File tree

1 file changed

+14
-2
lines changed

1 file changed

+14
-2
lines changed

array_api_strict/_array_object.py

+14-2
Original file line numberDiff line numberDiff line change
@@ -122,14 +122,26 @@ def __repr__(self: Array, /) -> str:
122122

123123
# This function is not required by the spec, but we implement it here for
124124
# convenience so that np.asarray(array_api_strict.Array) will work.
125-
def __array__(self, dtype: None | np.dtype[Any] = None) -> npt.NDArray[Any]:
125+
def __array__(self, dtype: None | np.dtype[Any] = None, copy: None | bool = None) -> npt.NDArray[Any]:
126126
"""
127127
Warning: this method is NOT part of the array API spec. Implementers
128128
of other libraries need not include it, and users should not assume it
129129
will be present in other implementations.
130130
131131
"""
132-
return np.asarray(self._array, dtype=dtype)
132+
# copy keyword is new in 2.0.0; for older versions don't use it
133+
# retry without that keyword.
134+
if np.__version__[0] < '2':
135+
return np.asarray(self._array, dtype=dtype)
136+
elif np.__version__.startswith('2.0.0-dev0'):
137+
# Handle dev version for which we can't know based on version
138+
# number whether or not the copy keyword is supported.
139+
try:
140+
return np.asarray(self._array, dtype=dtype, copy=copy)
141+
except TypeError:
142+
return np.asarray(self._array, dtype=dtype)
143+
else:
144+
return np.asarray(self._array, dtype=dtype, copy=copy)
133145

134146
# These are various helper functions to make the array behavior match the
135147
# spec in places where it either deviates from or is more strict than

0 commit comments

Comments
 (0)