Skip to content

Commit d84983a

Browse files
committed
Structure the copy flag in cupy.asarray better
This way it is more future-proof for when cupy changes the meaning of copy=False.
1 parent a1eea09 commit d84983a

File tree

1 file changed

+19
-6
lines changed

1 file changed

+19
-6
lines changed

array_api_compat/cupy/_aliases.py

+19-6
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@
6161
matrix_transpose = get_xp(cp)(_aliases.matrix_transpose)
6262
tensordot = get_xp(cp)(_aliases.tensordot)
6363

64+
_copy_default = object()
65+
6466
# asarray also adds the copy keyword, which is not present in numpy 1.0.
6567
def asarray(
6668
obj: Union[
@@ -75,7 +77,7 @@ def asarray(
7577
*,
7678
dtype: Optional[Dtype] = None,
7779
device: Optional[Device] = None,
78-
copy: Optional[bool] = None,
80+
copy: Optional[bool] = _copy_default,
7981
**kwargs,
8082
) -> ndarray:
8183
"""
@@ -90,12 +92,23 @@ def asarray(
9092
with cp.cuda.Device(device):
9193
# cupy is like NumPy 1.26 (except without _CopyMode). See the comments
9294
# in asarray in numpy/_aliases.py.
93-
if copy is None:
94-
copy = False
95-
elif copy is False:
96-
raise NotImplementedError("asarray(copy=False) is not yet supported in cupy")
95+
if copy is not _copy_default:
96+
# A future version of CuPy will change the meaning of copy=False
97+
# to mean no-copy. We don't know for certain what version it will
98+
# be yet, so to avoid breaking that version, we use a different
99+
# default value for copy so asarray(obj) with no copy kwarg will
100+
# always do the copy-if-needed behavior.
101+
102+
# This will still need to be updated to remove the
103+
# NotImplementedError for copy=False, but at least this won't
104+
# break the default or existing behavior.
105+
if copy is None:
106+
copy = False
107+
elif copy is False:
108+
raise NotImplementedError("asarray(copy=False) is not yet supported in cupy")
109+
kwargs['copy'] = copy
97110

98-
return cp.array(obj, copy=copy, dtype=dtype, **kwargs)
111+
return cp.array(obj, dtype=dtype, **kwargs)
99112

100113
# These functions are completely new here. If the library already has them
101114
# (i.e., numpy 2.0), use the library version instead of our wrapper.

0 commit comments

Comments
 (0)