Skip to content

Commit 73047a9

Browse files
committed
Add a test for the copy flag in asarray
This test currently fails because this logic isn't implemented correctly for numpy/cupy/dask. It does pass for pytorch.
1 parent f1068a3 commit 73047a9

File tree

1 file changed

+66
-1
lines changed

1 file changed

+66
-1
lines changed

tests/test_common.py

Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import pytest
99
import numpy as np
10+
import array
1011
from numpy.testing import assert_allclose
1112

1213
is_functions = {
@@ -64,7 +65,7 @@ def test_to_device_host(library):
6465

6566
@pytest.mark.parametrize("target_library,func", is_functions.items())
6667
@pytest.mark.parametrize("source_library", is_functions.keys())
67-
def test_asarray(source_library, target_library, func, request):
68+
def test_asarray_cross_library(source_library, target_library, func, request):
6869
if source_library == "dask.array" and target_library == "torch":
6970
# Allow rest of test to execute instead of immediately xfailing
7071
# xref https://github.com/pandas-dev/pandas/issues/38902
@@ -80,3 +81,67 @@ def test_asarray(source_library, target_library, func, request):
8081
b = tgt_lib.asarray(a)
8182

8283
assert is_tgt_type(b), f"Expected {b} to be a {tgt_lib.ndarray}, but was {type(b)}"
84+
85+
@pytest.mark.parametrize("library", wrapped_libraries)
86+
def test_asarray_copy(library):
87+
# Note, we have this test here because the test suite currently doesn't
88+
# test the copy flag to asarray() very rigorously. Once
89+
# https://github.com/data-apis/array-api-tests/issues/241 is fixed we
90+
# should be able to delete this.
91+
xp = import_(library, wrapper=True)
92+
asarray = xp.asarray
93+
is_lib_func = globals()[is_functions[library]]
94+
all = xp.all
95+
96+
a = asarray([1])
97+
b = asarray(a, copy=True)
98+
assert is_lib_func(b)
99+
a[0] = 0
100+
assert all(b[0] == 1)
101+
assert all(a[0] == 0)
102+
103+
a = asarray([1])
104+
b = asarray(a, copy=False)
105+
assert is_lib_func(b)
106+
a[0] = 0
107+
assert all(b[0] == 0)
108+
109+
a = asarray([1])
110+
pytest.raises(ValueError, lambda: asarray(a, copy=False, dtype=xp.float64))
111+
112+
a = asarray([1])
113+
b = asarray(a, copy=None)
114+
assert is_lib_func(b)
115+
a[0] = 0
116+
assert all(b[0] == 0)
117+
118+
a = asarray([1])
119+
b = asarray(a, dtype=xp.float64, copy=None)
120+
assert is_lib_func(b)
121+
a[0] = 0
122+
assert all(b[0] == 1)
123+
124+
# Python built-in types
125+
for obj in [True, 0, 0.0, 0j, [0], [[0]]]:
126+
asarray(obj, copy=True) # No error
127+
asarray(obj, copy=None) # No error
128+
pytest.raises(ValueError, lambda: asarray(obj, copy=False))
129+
130+
# Use the standard library array to test the buffer protocol
131+
a = array.array('f', [1.0])
132+
b = asarray(a, copy=True)
133+
assert is_lib_func(b)
134+
a[0] = 0.0
135+
assert all(b[0] == 1.0)
136+
137+
a = array.array('f', [1.0])
138+
b = asarray(a, copy=False)
139+
assert is_lib_func(b)
140+
a[0] = 0.0
141+
assert all(b[0] == 0.0)
142+
143+
a = array.array('f', [1.0])
144+
b = asarray(a, copy=None)
145+
assert is_lib_func(b)
146+
a[0] = 0.0
147+
assert all(b[0] == 0.0)

0 commit comments

Comments
 (0)