Skip to content

Commit f57ac54

Browse files
committed
Add a CuPy specific implementation for asarray
This also fixes the device keyword to actually work, and fixes copy=None logic. copy=False is not implemented and requires upstream support, which should come in CuPy 14.
1 parent 440e1c1 commit f57ac54

File tree

1 file changed

+41
-6
lines changed

1 file changed

+41
-6
lines changed

array_api_compat/cupy/_aliases.py

+41-6
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,14 @@
11
from __future__ import annotations
22

3-
from functools import partial
4-
53
import cupy as cp
64

75
from ..common import _aliases
86
from .._internal import get_xp
97

10-
asarray = asarray_cupy = partial(_aliases._asarray, namespace='cupy')
11-
asarray.__doc__ = _aliases._asarray.__doc__
12-
del partial
8+
from typing import TYPE_CHECKING
9+
if TYPE_CHECKING:
10+
from typing import Optional, Union
11+
from ._typing import ndarray, Device, Dtype, NestedSequence, SupportsBufferProtocol
1312

1413
bool = cp.bool_
1514

@@ -62,6 +61,42 @@
6261
matrix_transpose = get_xp(cp)(_aliases.matrix_transpose)
6362
tensordot = get_xp(cp)(_aliases.tensordot)
6463

64+
# asarray also adds the copy keyword, which is not present in numpy 1.0.
65+
def asarray(
66+
obj: Union[
67+
ndarray,
68+
bool,
69+
int,
70+
float,
71+
NestedSequence[bool | int | float],
72+
SupportsBufferProtocol,
73+
],
74+
/,
75+
*,
76+
dtype: Optional[Dtype] = None,
77+
device: Optional[Device] = None,
78+
copy: Optional[bool] = None,
79+
**kwargs,
80+
) -> ndarray:
81+
"""
82+
Array API compatibility wrapper for asarray().
83+
84+
See the corresponding documentation in the array library and/or the array API
85+
specification for more details.
86+
87+
'namespace' may be an array module namespace. This is needed to support
88+
conversion of sequences of Python scalars.
89+
"""
90+
with cp.cuda.Device(device):
91+
# cupy is like NumPy 1.26 (except without _CopyMode). See the comments
92+
# in asarray in numpy/_aliases.py.
93+
if copy is None:
94+
copy = False
95+
elif copy is False:
96+
raise NotImplementedError("asarray(copy=False) is not yet supported in cupy")
97+
98+
return cp.array(obj, copy=copy, dtype=dtype, **kwargs)
99+
65100
# These functions are completely new here. If the library already has them
66101
# (i.e., numpy 2.0), use the library version instead of our wrapper.
67102
if hasattr(cp, 'vecdot'):
@@ -73,7 +108,7 @@
73108
else:
74109
isdtype = get_xp(cp)(_aliases.isdtype)
75110

76-
__all__ = _aliases.__all__ + ['asarray', 'asarray_cupy', 'bool', 'acos',
111+
__all__ = _aliases.__all__ + ['asarray', 'bool', 'acos',
77112
'acosh', 'asin', 'asinh', 'atan', 'atan2',
78113
'atanh', 'bitwise_left_shift', 'bitwise_invert',
79114
'bitwise_right_shift', 'concat', 'pow']

0 commit comments

Comments
 (0)