Skip to content

Commit 558e4ed

Browse files
committed
Move the asarray numpy implementation to numpy/_aliases
This now properly supports all relevant versions of NumPy. This also removes the asarray_numpy function from the namespace. I plan to do the same for cupy and dask as well.
1 parent fa36c20 commit 558e4ed

File tree

3 files changed

+98
-16
lines changed

3 files changed

+98
-16
lines changed

array_api_compat/numpy/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,4 +21,10 @@
2121

2222
from ..common._helpers import * # noqa: F403
2323

24+
try:
25+
# Used in asarray(). Not present in older versions.
26+
from numpy import _CopyMode
27+
except ImportError:
28+
pass
29+
2430
__array_api_version__ = '2022.12'

array_api_compat/numpy/_aliases.py

Lines changed: 64 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
11
from __future__ import annotations
22

3-
from functools import partial
4-
53
from ..common import _aliases
64

75
from .._internal import get_xp
86

9-
asarray = asarray_numpy = partial(_aliases._asarray, namespace='numpy')
10-
asarray.__doc__ = _aliases._asarray.__doc__
11-
del partial
7+
from typing import TYPE_CHECKING
8+
if TYPE_CHECKING:
9+
from typing import Optional, Union
10+
from ._typing import ndarray, Device, Dtype, NestedSequence, SupportsBufferProtocol
1211

1312
import numpy as np
1413
bool = np.bool_
@@ -62,6 +61,65 @@
6261
matrix_transpose = get_xp(np)(_aliases.matrix_transpose)
6362
tensordot = get_xp(np)(_aliases.tensordot)
6463

64+
def _supports_buffer_protocol(obj):
65+
try:
66+
memoryview(obj)
67+
except TypeError:
68+
return False
69+
return True
70+
71+
# asarray also adds the copy keyword, which is not present in numpy 1.0.
72+
# asarray() is different enough between numpy, cupy, and dask, the logic
73+
# complicated enough that it's easier to define it separately for each module
74+
# rather than trying to combine everything into one function in common/
75+
def asarray(
76+
obj: Union[
77+
ndarray,
78+
bool,
79+
int,
80+
float,
81+
NestedSequence[bool | int | float],
82+
SupportsBufferProtocol,
83+
],
84+
/,
85+
*,
86+
dtype: Optional[Dtype] = None,
87+
device: Optional[Device] = None,
88+
copy: "Optional[Union[bool, np._CopyMode]]" = None,
89+
**kwargs,
90+
) -> ndarray:
91+
"""
92+
Array API compatibility wrapper for asarray().
93+
94+
See the corresponding documentation in the array library and/or the array API
95+
specification for more details.
96+
97+
'namespace' may be an array module namespace. This is needed to support
98+
conversion of sequences of Python scalars.
99+
"""
100+
if np.__version__[0] >= '2':
101+
# NumPy 2.0 asarray() is completely array API compatible. No need for
102+
# the complicated logic below
103+
return np.asarray(obj, dtype=dtype, device=device, copy=copy, **kwargs)
104+
105+
if device not in ["cpu", None]:
106+
raise ValueError(f"Unsupported device for NumPy: {device!r}")
107+
108+
if hasattr(np, '_CopyMode'):
109+
if copy is None:
110+
copy = np._CopyMode.IF_NEEDED
111+
elif copy is False:
112+
copy = np._CopyMode.NEVER
113+
elif copy is True:
114+
copy = np._CopyMode.ALWAYS
115+
else:
116+
# Not present in older NumPys. In this case, we cannot really support
117+
# copy=False.
118+
if copy is False:
119+
raise NotImplementedError("asarray(copy=False) requires a newer version of NumPy.")
120+
121+
return np.array(obj, copy=copy, dtype=dtype, **kwargs)
122+
65123
# These functions are completely new here. If the library already has them
66124
# (i.e., numpy 2.0), use the library version instead of our wrapper.
67125
if hasattr(np, 'vecdot'):
@@ -73,7 +131,7 @@
73131
else:
74132
isdtype = get_xp(np)(_aliases.isdtype)
75133

76-
__all__ = _aliases.__all__ + ['asarray', 'asarray_numpy', 'bool', 'acos',
134+
__all__ = _aliases.__all__ + ['asarray', 'bool', 'acos',
77135
'acosh', 'asin', 'asinh', 'atan', 'atan2',
78136
'atanh', 'bitwise_left_shift', 'bitwise_invert',
79137
'bitwise_right_shift', 'concat', 'pow']

tests/test_common.py

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,11 @@ def test_asarray_copy(library):
9393
is_lib_func = globals()[is_functions[library]]
9494
all = xp.all
9595

96+
if library == 'numpy' and xp.__version__[0] < '2' and not hasattr(xp, '_CopyMode') :
97+
supports_copy_false = False
98+
else:
99+
supports_copy_false = True
100+
96101
a = asarray([1])
97102
b = asarray(a, copy=True)
98103
assert is_lib_func(b)
@@ -101,13 +106,20 @@ def test_asarray_copy(library):
101106
assert all(a[0] == 0)
102107

103108
a = asarray([1])
104-
b = asarray(a, copy=False)
105-
assert is_lib_func(b)
106-
a[0] = 0
107-
assert all(b[0] == 0)
109+
if supports_copy_false:
110+
b = asarray(a, copy=False)
111+
assert is_lib_func(b)
112+
a[0] = 0
113+
assert all(b[0] == 0)
114+
else:
115+
pytest.raises(NotImplementedError, lambda: asarray(a, copy=False))
108116

109117
a = asarray([1])
110-
pytest.raises(ValueError, lambda: asarray(a, copy=False, dtype=xp.float64))
118+
if supports_copy_false:
119+
pytest.raises(ValueError, lambda: asarray(a, copy=False,
120+
dtype=xp.float64))
121+
else:
122+
pytest.raises(NotImplementedError, lambda: asarray(a, copy=False, dtype=xp.float64))
111123

112124
a = asarray([1])
113125
b = asarray(a, copy=None)
@@ -131,7 +143,10 @@ def test_asarray_copy(library):
131143
for obj in [True, 0, 0.0, 0j, [0], [[0]]]:
132144
asarray(obj, copy=True) # No error
133145
asarray(obj, copy=None) # No error
134-
pytest.raises(ValueError, lambda: asarray(obj, copy=False))
146+
if supports_copy_false:
147+
pytest.raises(ValueError, lambda: asarray(obj, copy=False))
148+
else:
149+
pytest.raises(NotImplementedError, lambda: asarray(obj, copy=False))
135150

136151
# Use the standard library array to test the buffer protocol
137152
a = array.array('f', [1.0])
@@ -141,10 +156,13 @@ def test_asarray_copy(library):
141156
assert all(b[0] == 1.0)
142157

143158
a = array.array('f', [1.0])
144-
b = asarray(a, copy=False)
145-
assert is_lib_func(b)
146-
a[0] = 0.0
147-
assert all(b[0] == 0.0)
159+
if supports_copy_false:
160+
b = asarray(a, copy=False)
161+
assert is_lib_func(b)
162+
a[0] = 0.0
163+
assert all(b[0] == 0.0)
164+
else:
165+
pytest.raises(NotImplementedError, lambda: asarray(a, copy=False))
148166

149167
a = array.array('f', [1.0])
150168
b = asarray(a, copy=None)

0 commit comments

Comments
 (0)