|
1 | 1 | from __future__ import annotations
|
2 | 2 |
|
3 |
| -from functools import partial |
4 |
| - |
5 | 3 | from ..common import _aliases
|
6 | 4 |
|
7 | 5 | from .._internal import get_xp
|
8 | 6 |
|
9 |
| -asarray = asarray_numpy = partial(_aliases._asarray, namespace='numpy') |
10 |
| -asarray.__doc__ = _aliases._asarray.__doc__ |
11 |
| -del partial |
| 7 | +from typing import TYPE_CHECKING |
| 8 | +if TYPE_CHECKING: |
| 9 | + from typing import Optional, Union |
| 10 | + from ._typing import ndarray, Device, Dtype, NestedSequence, SupportsBufferProtocol |
12 | 11 |
|
13 | 12 | import numpy as np
|
14 | 13 | bool = np.bool_
|
|
62 | 61 | matrix_transpose = get_xp(np)(_aliases.matrix_transpose)
|
63 | 62 | tensordot = get_xp(np)(_aliases.tensordot)
|
64 | 63 |
|
| 64 | +def _supports_buffer_protocol(obj): |
| 65 | + try: |
| 66 | + memoryview(obj) |
| 67 | + except TypeError: |
| 68 | + return False |
| 69 | + return True |
| 70 | + |
| 71 | +# asarray also adds the copy keyword, which is not present in numpy 1.0. |
| 72 | +# asarray() is different enough between numpy, cupy, and dask, the logic |
| 73 | +# complicated enough that it's easier to define it separately for each module |
| 74 | +# rather than trying to combine everything into one function in common/ |
| 75 | +def asarray( |
| 76 | + obj: Union[ |
| 77 | + ndarray, |
| 78 | + bool, |
| 79 | + int, |
| 80 | + float, |
| 81 | + NestedSequence[bool | int | float], |
| 82 | + SupportsBufferProtocol, |
| 83 | + ], |
| 84 | + /, |
| 85 | + *, |
| 86 | + dtype: Optional[Dtype] = None, |
| 87 | + device: Optional[Device] = None, |
| 88 | + copy: "Optional[Union[bool, np._CopyMode]]" = None, |
| 89 | + **kwargs, |
| 90 | +) -> ndarray: |
| 91 | + """ |
| 92 | + Array API compatibility wrapper for asarray(). |
| 93 | +
|
| 94 | + See the corresponding documentation in the array library and/or the array API |
| 95 | + specification for more details. |
| 96 | +
|
| 97 | + 'namespace' may be an array module namespace. This is needed to support |
| 98 | + conversion of sequences of Python scalars. |
| 99 | + """ |
| 100 | + if np.__version__[0] >= '2': |
| 101 | + # NumPy 2.0 asarray() is completely array API compatible. No need for |
| 102 | + # the complicated logic below |
| 103 | + return np.asarray(obj, dtype=dtype, device=device, copy=copy, **kwargs) |
| 104 | + |
| 105 | + if device not in ["cpu", None]: |
| 106 | + raise ValueError(f"Unsupported device for NumPy: {device!r}") |
| 107 | + |
| 108 | + if hasattr(np, '_CopyMode'): |
| 109 | + if copy is None: |
| 110 | + copy = np._CopyMode.IF_NEEDED |
| 111 | + elif copy is False: |
| 112 | + copy = np._CopyMode.NEVER |
| 113 | + elif copy is True: |
| 114 | + copy = np._CopyMode.ALWAYS |
| 115 | + else: |
| 116 | + # Not present in older NumPys. In this case, we cannot really support |
| 117 | + # copy=False. |
| 118 | + if copy is False: |
| 119 | + raise NotImplementedError("asarray(copy=False) requires a newer version of NumPy.") |
| 120 | + |
| 121 | + return np.array(obj, copy=copy, dtype=dtype, **kwargs) |
| 122 | + |
65 | 123 | # These functions are completely new here. If the library already has them
|
66 | 124 | # (i.e., numpy 2.0), use the library version instead of our wrapper.
|
67 | 125 | if hasattr(np, 'vecdot'):
|
|
73 | 131 | else:
|
74 | 132 | isdtype = get_xp(np)(_aliases.isdtype)
|
75 | 133 |
|
76 |
| -__all__ = _aliases.__all__ + ['asarray', 'asarray_numpy', 'bool', 'acos', |
| 134 | +__all__ = _aliases.__all__ + ['asarray', 'bool', 'acos', |
77 | 135 | 'acosh', 'asin', 'asinh', 'atan', 'atan2',
|
78 | 136 | 'atanh', 'bitwise_left_shift', 'bitwise_invert',
|
79 | 137 | 'bitwise_right_shift', 'concat', 'pow']
|
|
0 commit comments