7
7
8
8
import pytest
9
9
import numpy as np
10
+ import array
10
11
from numpy .testing import assert_allclose
11
12
12
13
is_functions = {
@@ -64,7 +65,7 @@ def test_to_device_host(library):
64
65
65
66
@pytest .mark .parametrize ("target_library,func" , is_functions .items ())
66
67
@pytest .mark .parametrize ("source_library" , is_functions .keys ())
67
- def test_asarray (source_library , target_library , func , request ):
68
+ def test_asarray_cross_library (source_library , target_library , func , request ):
68
69
if source_library == "dask.array" and target_library == "torch" :
69
70
# Allow rest of test to execute instead of immediately xfailing
70
71
# xref https://github.com/pandas-dev/pandas/issues/38902
@@ -80,3 +81,67 @@ def test_asarray(source_library, target_library, func, request):
80
81
b = tgt_lib .asarray (a )
81
82
82
83
assert is_tgt_type (b ), f"Expected { b } to be a { tgt_lib .ndarray } , but was { type (b )} "
84
+
85
+ @pytest .mark .parametrize ("library" , wrapped_libraries )
86
+ def test_asarray_copy (library ):
87
+ # Note, we have this test here because the test suite currently doesn't
88
+ # test the copy flag to asarray() very rigorously. Once
89
+ # https://github.com/data-apis/array-api-tests/issues/241 is fixed we
90
+ # should be able to delete this.
91
+ xp = import_ (library , wrapper = True )
92
+ asarray = xp .asarray
93
+ is_lib_func = globals ()[is_functions [library ]]
94
+ all = xp .all
95
+
96
+ a = asarray ([1 ])
97
+ b = asarray (a , copy = True )
98
+ assert is_lib_func (b )
99
+ a [0 ] = 0
100
+ assert all (b [0 ] == 1 )
101
+ assert all (a [0 ] == 0 )
102
+
103
+ a = asarray ([1 ])
104
+ b = asarray (a , copy = False )
105
+ assert is_lib_func (b )
106
+ a [0 ] = 0
107
+ assert all (b [0 ] == 0 )
108
+
109
+ a = asarray ([1 ])
110
+ pytest .raises (ValueError , lambda : asarray (a , copy = False , dtype = xp .float64 ))
111
+
112
+ a = asarray ([1 ])
113
+ b = asarray (a , copy = None )
114
+ assert is_lib_func (b )
115
+ a [0 ] = 0
116
+ assert all (b [0 ] == 0 )
117
+
118
+ a = asarray ([1 ])
119
+ b = asarray (a , dtype = xp .float64 , copy = None )
120
+ assert is_lib_func (b )
121
+ a [0 ] = 0
122
+ assert all (b [0 ] == 1 )
123
+
124
+ # Python built-in types
125
+ for obj in [True , 0 , 0.0 , 0j , [0 ], [[0 ]]]:
126
+ asarray (obj , copy = True ) # No error
127
+ asarray (obj , copy = None ) # No error
128
+ pytest .raises (ValueError , lambda : asarray (obj , copy = False ))
129
+
130
+ # Use the standard library array to test the buffer protocol
131
+ a = array .array ('f' , [1.0 ])
132
+ b = asarray (a , copy = True )
133
+ assert is_lib_func (b )
134
+ a [0 ] = 0.0
135
+ assert all (b [0 ] == 1.0 )
136
+
137
+ a = array .array ('f' , [1.0 ])
138
+ b = asarray (a , copy = False )
139
+ assert is_lib_func (b )
140
+ a [0 ] = 0.0
141
+ assert all (b [0 ] == 0.0 )
142
+
143
+ a = array .array ('f' , [1.0 ])
144
+ b = asarray (a , copy = None )
145
+ assert is_lib_func (b )
146
+ a [0 ] = 0.0
147
+ assert all (b [0 ] == 0.0 )
0 commit comments