|
1 | 1 | from __future__ import annotations
|
2 | 2 |
|
3 |
| -from functools import partial |
4 |
| - |
5 | 3 | import cupy as cp
|
6 | 4 |
|
7 | 5 | from ..common import _aliases
|
8 | 6 | from .._internal import get_xp
|
9 | 7 |
|
10 |
| -asarray = asarray_cupy = partial(_aliases._asarray, namespace='cupy') |
11 |
| -asarray.__doc__ = _aliases._asarray.__doc__ |
12 |
| -del partial |
| 8 | +from typing import TYPE_CHECKING |
| 9 | +if TYPE_CHECKING: |
| 10 | + from typing import Optional, Union |
| 11 | + from ._typing import ndarray, Device, Dtype, NestedSequence, SupportsBufferProtocol |
13 | 12 |
|
14 | 13 | bool = cp.bool_
|
15 | 14 |
|
|
62 | 61 | matrix_transpose = get_xp(cp)(_aliases.matrix_transpose)
|
63 | 62 | tensordot = get_xp(cp)(_aliases.tensordot)
|
64 | 63 |
|
| 64 | +# asarray also adds the copy keyword, which is not present in numpy 1.0. |
| 65 | +def asarray( |
| 66 | + obj: Union[ |
| 67 | + ndarray, |
| 68 | + bool, |
| 69 | + int, |
| 70 | + float, |
| 71 | + NestedSequence[bool | int | float], |
| 72 | + SupportsBufferProtocol, |
| 73 | + ], |
| 74 | + /, |
| 75 | + *, |
| 76 | + dtype: Optional[Dtype] = None, |
| 77 | + device: Optional[Device] = None, |
| 78 | + copy: Optional[bool] = None, |
| 79 | + **kwargs, |
| 80 | +) -> ndarray: |
| 81 | + """ |
| 82 | + Array API compatibility wrapper for asarray(). |
| 83 | +
|
| 84 | + See the corresponding documentation in the array library and/or the array API |
| 85 | + specification for more details. |
| 86 | +
|
| 87 | + 'namespace' may be an array module namespace. This is needed to support |
| 88 | + conversion of sequences of Python scalars. |
| 89 | + """ |
| 90 | + with cp.cuda.Device(device): |
| 91 | + # cupy is like NumPy 1.26 (except without _CopyMode). See the comments |
| 92 | + # in asarray in numpy/_aliases.py. |
| 93 | + if copy is None: |
| 94 | + copy = False |
| 95 | + elif copy is False: |
| 96 | + raise NotImplementedError("asarray(copy=False) is not yet supported in cupy") |
| 97 | + |
| 98 | + return cp.array(obj, copy=copy, dtype=dtype, **kwargs) |
| 99 | + |
65 | 100 | # These functions are completely new here. If the library already has them
|
66 | 101 | # (i.e., numpy 2.0), use the library version instead of our wrapper.
|
67 | 102 | if hasattr(cp, 'vecdot'):
|
|
73 | 108 | else:
|
74 | 109 | isdtype = get_xp(cp)(_aliases.isdtype)
|
75 | 110 |
|
76 |
| -__all__ = _aliases.__all__ + ['asarray', 'asarray_cupy', 'bool', 'acos', |
| 111 | +__all__ = _aliases.__all__ + ['asarray', 'bool', 'acos', |
77 | 112 | 'acosh', 'asin', 'asinh', 'atan', 'atan2',
|
78 | 113 | 'atanh', 'bitwise_left_shift', 'bitwise_invert',
|
79 | 114 | 'bitwise_right_shift', 'concat', 'pow']
|
|
0 commit comments