Skip to content

Commit 8565b2c

Browse files
authored
Fix asarray(copy=False) (#15)
For NumPy 2.0, this is implemented directly. For NumPy 1, we emulate it by checking if asarray() creates a copy or not. This also removes support for the np._CopyMode enum in asarray(), as this is not portable.
1 parent d10e2df commit 8565b2c

File tree

3 files changed

+70
-22
lines changed

3 files changed

+70
-22
lines changed

Diff for: array-api-tests-xfails.txt

-4
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,3 @@
1-
# copy=False is not yet implemented
2-
# https://github.com/numpy/numpy/pull/25168
3-
array_api_tests/test_creation_functions.py::test_asarray_arrays
4-
51
# Known special case issue in NumPy. Not worth working around here
62
# https://github.com/numpy/numpy/issues/21213
73
array_api_tests/test_special_cases.py::test_iop[__ipow__(x1_i is -infinity and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +infinity]

Diff for: array_api_strict/_creation_functions.py

+33-11
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
NestedSequence,
1212
SupportsBufferProtocol,
1313
)
14-
from collections.abc import Sequence
1514
from ._dtypes import _DType, _all_dtypes
1615

1716
import numpy as np
@@ -22,6 +21,12 @@ def _check_valid_dtype(dtype):
2221
if dtype not in (None,) + _all_dtypes:
2322
raise ValueError("dtype must be one of the supported dtypes")
2423

24+
def _supports_buffer_protocol(obj):
25+
try:
26+
memoryview(obj)
27+
except TypeError:
28+
return False
29+
return True
2530

2631
def asarray(
2732
obj: Union[
@@ -36,7 +41,7 @@ def asarray(
3641
*,
3742
dtype: Optional[Dtype] = None,
3843
device: Optional[Device] = None,
39-
copy: Optional[Union[bool, np._CopyMode]] = None,
44+
copy: Optional[bool] = None,
4045
) -> Array:
4146
"""
4247
Array API compatible wrapper for :py:func:`np.asarray <numpy.asarray>`.
@@ -53,20 +58,37 @@ def asarray(
5358
_np_dtype = dtype._np_dtype
5459
if device not in [CPU_DEVICE, None]:
5560
raise ValueError(f"Unsupported device {device!r}")
56-
if copy in (False, np._CopyMode.IF_NEEDED):
57-
# Note: copy=False is not yet implemented in np.asarray
58-
raise NotImplementedError("copy=False is not yet implemented")
61+
62+
if np.__version__[0] < '2':
63+
if copy is False:
64+
# Note: copy=False is not yet implemented in np.asarray for
65+
# NumPy 1
66+
67+
# Work around it by creating the new array and seeing if NumPy
68+
# copies it.
69+
if isinstance(obj, Array):
70+
new_array = np.array(obj._array, copy=copy, dtype=_np_dtype)
71+
if new_array is not obj._array:
72+
raise ValueError("Unable to avoid copy while creating an array from given array.")
73+
return Array._new(new_array)
74+
elif _supports_buffer_protocol(obj):
75+
# Buffer protocol will always support no-copy
76+
return Array._new(np.array(obj, copy=copy, dtype=_np_dtype))
77+
else:
78+
# No-copy is unsupported for Python built-in types.
79+
raise ValueError("Unable to avoid copy while creating an array from given object.")
80+
81+
if copy is None:
82+
# NumPy 1 treats copy=False the same as the standard copy=None
83+
copy = False
84+
5985
if isinstance(obj, Array):
60-
if dtype is not None and obj.dtype != dtype:
61-
copy = True
62-
if copy in (True, np._CopyMode.ALWAYS):
63-
return Array._new(np.array(obj._array, copy=True, dtype=_np_dtype))
64-
return obj
86+
return Array._new(np.array(obj._array, copy=copy, dtype=_np_dtype))
6587
if dtype is None and isinstance(obj, int) and (obj > 2 ** 64 or obj < -(2 ** 63)):
6688
# Give a better error message in this case. NumPy would convert this
6789
# to an object array. TODO: This won't handle large integers in lists.
6890
raise OverflowError("Integer out of bounds for array dtypes")
69-
res = np.asarray(obj, dtype=_np_dtype)
91+
res = np.array(obj, dtype=_np_dtype, copy=copy)
7092
return Array._new(res)
7193

7294

Diff for: array_api_strict/tests/test_creation_functions.py

+37-7
Original file line numberDiff line numberDiff line change
@@ -50,19 +50,49 @@ def test_asarray_copy():
5050
a[0] = 0
5151
assert all(b[0] == 1)
5252
assert all(a[0] == 0)
53+
5354
a = asarray([1])
54-
b = asarray(a, copy=np._CopyMode.ALWAYS)
55+
b = asarray(a, copy=False)
5556
a[0] = 0
56-
assert all(b[0] == 1)
57-
assert all(a[0] == 0)
57+
assert all(b[0] == 0)
58+
59+
a = asarray([1])
60+
assert_raises(ValueError, lambda: asarray(a, copy=False, dtype=float64))
61+
62+
a = asarray([1])
63+
b = asarray(a, copy=None)
64+
a[0] = 0
65+
assert all(b[0] == 0)
66+
5867
a = asarray([1])
59-
b = asarray(a, copy=np._CopyMode.NEVER)
68+
b = asarray(a, dtype=float64, copy=None)
69+
a[0] = 0
70+
assert all(b[0] == 1.0)
71+
72+
# Python built-in types
73+
for obj in [True, 0, 0.0, 0j, [0], [[0]]]:
74+
asarray(obj, copy=True) # No error
75+
asarray(obj, copy=None) # No error
76+
assert_raises(ValueError, lambda: asarray(obj, copy=False))
77+
78+
# Buffer protocol
79+
a = np.array([1])
80+
b = asarray(a, copy=True)
81+
assert isinstance(b, Array)
82+
a[0] = 0
83+
assert all(b[0] == 1)
84+
85+
a = np.array([1])
86+
b = asarray(a, copy=False)
87+
assert isinstance(b, Array)
6088
a[0] = 0
6189
assert all(b[0] == 0)
62-
assert_raises(NotImplementedError, lambda: asarray(a, copy=False))
63-
assert_raises(NotImplementedError,
64-
lambda: asarray(a, copy=np._CopyMode.IF_NEEDED))
6590

91+
a = np.array([1])
92+
b = asarray(a, copy=None)
93+
assert isinstance(b, Array)
94+
a[0] = 0
95+
assert all(b[0] == 0)
6696

6797
def test_arange_errors():
6898
arange(1, device=CPU_DEVICE) # Doesn't error

0 commit comments

Comments
 (0)