@@ -268,7 +268,6 @@ def test_asarray_cross_library(source_library, target_library, request):
268
268
assert b .dtype == tgt_lib .int32
269
269
270
270
271
-
272
271
@pytest .mark .parametrize ("library" , wrapped_libraries )
273
272
def test_asarray_copy (library ):
274
273
# Note, we have this test here because the test suite currently doesn't
@@ -278,100 +277,87 @@ def test_asarray_copy(library):
278
277
xp = import_ (library , wrapper = True )
279
278
asarray = xp .asarray
280
279
is_lib_func = globals ()[is_array_functions [library ]]
281
- all = xp .all if library != 'dask.array' else lambda x : xp .all (x ).compute ()
282
-
283
- if library == 'cupy' :
284
- supports_copy_false_other_ns = False
285
- supports_copy_false_same_ns = False
286
- elif library == 'dask.array' :
287
- supports_copy_false_other_ns = False
288
- supports_copy_false_same_ns = True
289
- else :
290
- supports_copy_false_other_ns = True
291
- supports_copy_false_same_ns = True
292
280
293
281
a = asarray ([1 ])
294
282
b = asarray (a , copy = True )
295
283
assert is_lib_func (b )
296
284
a [0 ] = 0
297
- assert all ( b [0 ] == 1 )
298
- assert all ( a [0 ] == 0 )
285
+ assert b [0 ] == 1
286
+ assert a [0 ] == 0
299
287
300
288
a = asarray ([1 ])
301
- if supports_copy_false_same_ns :
302
- b = asarray (a , copy = False )
303
- assert is_lib_func (b )
304
- a [0 ] = 0
305
- assert all (b [0 ] == 0 )
306
- else :
307
- pytest .raises (NotImplementedError , lambda : asarray (a , copy = False ))
308
289
309
- a = asarray ([1 ])
310
- if supports_copy_false_same_ns :
311
- pytest .raises (ValueError , lambda : asarray (a , copy = False ,
312
- dtype = xp .float64 ))
313
- else :
314
- pytest .raises (NotImplementedError , lambda : asarray (a , copy = False , dtype = xp .float64 ))
290
+ # Test copy=False within the same namespace
291
+ b = asarray (a , copy = False )
292
+ assert is_lib_func (b )
293
+ a [0 ] = 0
294
+ assert b [0 ] == 0
295
+ with pytest .raises (ValueError ):
296
+ asarray (a , copy = False , dtype = xp .float64 )
315
297
298
+ # copy=None defaults to False when possible
316
299
a = asarray ([1 ])
317
300
b = asarray (a , copy = None )
318
301
assert is_lib_func (b )
319
302
a [0 ] = 0
320
- assert all ( b [0 ] == 0 )
303
+ assert b [0 ] == 0
321
304
305
+ # copy=None defaults to True when impossible
322
306
a = asarray ([1.0 ], dtype = xp .float32 )
323
307
assert a .dtype == xp .float32
324
308
b = asarray (a , dtype = xp .float64 , copy = None )
325
309
assert is_lib_func (b )
326
310
assert b .dtype == xp .float64
327
311
a [0 ] = 0.0
328
- assert all ( b [0 ] == 1.0 )
312
+ assert b [0 ] == 1.0
329
313
314
+ # copy=None defaults to False when possible
330
315
a = asarray ([1.0 ], dtype = xp .float64 )
331
316
assert a .dtype == xp .float64
332
317
b = asarray (a , dtype = xp .float64 , copy = None )
333
318
assert is_lib_func (b )
334
319
assert b .dtype == xp .float64
335
320
a [0 ] = 0.0
336
- assert all ( b [0 ] == 0.0 )
321
+ assert b [0 ] == 0.0
337
322
338
323
# Python built-in types
339
324
for obj in [True , 0 , 0.0 , 0j , [0 ], [[0 ]]]:
340
- asarray (obj , copy = True ) # No error
341
- asarray (obj , copy = None ) # No error
342
- if supports_copy_false_other_ns :
343
- pytest .raises (ValueError , lambda : asarray (obj , copy = False ))
344
- else :
345
- pytest .raises (NotImplementedError , lambda : asarray (obj , copy = False ))
325
+ asarray (obj , copy = True ) # No error
326
+ asarray (obj , copy = None ) # No error
327
+
328
+ with pytest .raises (ValueError ):
329
+ asarray (obj , copy = False )
346
330
347
331
# Use the standard library array to test the buffer protocol
348
- a = array .array ('f' , [1.0 ])
332
+ a = array .array ("f" , [1.0 ])
349
333
b = asarray (a , copy = True )
350
334
assert is_lib_func (b )
351
335
a [0 ] = 0.0
352
- assert all ( b [0 ] == 1.0 )
336
+ assert b [0 ] == 1.0
353
337
354
- a = array .array ('f' , [1.0 ])
355
- if supports_copy_false_other_ns :
338
+ a = array .array ("f" , [1.0 ])
339
+ if library in ("cupy" , "dask.array" ):
340
+ with pytest .raises (ValueError ):
341
+ asarray (a , copy = False )
342
+ else :
356
343
b = asarray (a , copy = False )
357
344
assert is_lib_func (b )
358
345
a [0 ] = 0.0
359
- assert all (b [0 ] == 0.0 )
360
- else :
361
- pytest .raises (NotImplementedError , lambda : asarray (a , copy = False ))
346
+ assert b [0 ] == 0.0
362
347
363
- a = array .array ('f' , [1.0 ])
348
+ a = array .array ("f" , [1.0 ])
364
349
b = asarray (a , copy = None )
365
350
assert is_lib_func (b )
366
351
a [0 ] = 0.0
367
- if library in (' cupy' , ' dask.array' ):
352
+ if library in (" cupy" , " dask.array" ):
368
353
# A copy is required for libraries where the default device is not CPU
369
354
# dask changed behaviour of copy=None in 2024.12 to copy;
370
355
# this wrapper ensures the same behaviour in older versions too.
371
356
# https://github.com/dask/dask/pull/11524/
372
- assert all ( b [0 ] == 1.0 )
357
+ assert b [0 ] == 1.0
373
358
else :
374
- assert all (b [0 ] == 0.0 )
359
+ # copy=None defaults to False when possible
360
+ assert b [0 ] == 0.0
375
361
376
362
377
363
@pytest .mark .parametrize ("library" , ["numpy" , "cupy" , "torch" ])
0 commit comments