Skip to content

Commit ab6f323

Browse files
authored
Merge pull request #12 from rgommers/copy-keyword
Support `copy` keyword in `__array__` method
2 parents 08a13c0 + e4f775d commit ab6f323

File tree

1 file changed

+14
-2
lines changed

1 file changed

+14
-2
lines changed

Diff for: array_api_strict/_array_object.py

+14-2
Original file line numberDiff line numberDiff line change
@@ -122,14 +122,26 @@ def __repr__(self: Array, /) -> str:
122122

123123
# This function is not required by the spec, but we implement it here for
124124
# convenience so that np.asarray(array_api_strict.Array) will work.
125-
def __array__(self, dtype: None | np.dtype[Any] = None) -> npt.NDArray[Any]:
125+
def __array__(self, dtype: None | np.dtype[Any] = None, copy: None | bool = None) -> npt.NDArray[Any]:
126126
"""
127127
Warning: this method is NOT part of the array API spec. Implementers
128128
of other libraries need not include it, and users should not assume it
129129
will be present in other implementations.
130130
131131
"""
132-
return np.asarray(self._array, dtype=dtype)
132+
# copy keyword is new in 2.0.0; for older versions don't use it
133+
# retry without that keyword.
134+
if np.__version__[0] < '2':
135+
return np.asarray(self._array, dtype=dtype)
136+
elif np.__version__.startswith('2.0.0-dev0'):
137+
# Handle dev version for which we can't know based on version
138+
# number whether or not the copy keyword is supported.
139+
try:
140+
return np.asarray(self._array, dtype=dtype, copy=copy)
141+
except TypeError:
142+
return np.asarray(self._array, dtype=dtype)
143+
else:
144+
return np.asarray(self._array, dtype=dtype, copy=copy)
133145

134146
# These are various helper functions to make the array behavior match the
135147
# spec in places where it either deviates from or is more strict than

0 commit comments

Comments
 (0)