Skip to content

Commit 559c82b

Browse files
authored
Merge pull request #325 from crusaderky/asarray_copy
TST: revisit test for `asarray` `copy=` parameter
2 parents 1b4ba64 + 44e7828 commit 559c82b

File tree

2 files changed

+34
-48
lines changed

2 files changed

+34
-48
lines changed

array_api_compat/dask/array/_aliases.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ def asarray(
171171
return obj.copy() if copy else obj # pyright: ignore[reportAttributeAccessIssue]
172172

173173
if copy is False:
174-
raise NotImplementedError(
174+
raise ValueError(
175175
"Unable to avoid copy when converting a non-dask object to dask"
176176
)
177177

tests/test_common.py

Lines changed: 33 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,6 @@ def test_asarray_cross_library(source_library, target_library, request):
268268
assert b.dtype == tgt_lib.int32
269269

270270

271-
272271
@pytest.mark.parametrize("library", wrapped_libraries)
273272
def test_asarray_copy(library):
274273
# Note, we have this test here because the test suite currently doesn't
@@ -278,100 +277,87 @@ def test_asarray_copy(library):
278277
xp = import_(library, wrapper=True)
279278
asarray = xp.asarray
280279
is_lib_func = globals()[is_array_functions[library]]
281-
all = xp.all if library != 'dask.array' else lambda x: xp.all(x).compute()
282-
283-
if library == 'cupy':
284-
supports_copy_false_other_ns = False
285-
supports_copy_false_same_ns = False
286-
elif library == 'dask.array':
287-
supports_copy_false_other_ns = False
288-
supports_copy_false_same_ns = True
289-
else:
290-
supports_copy_false_other_ns = True
291-
supports_copy_false_same_ns = True
292280

293281
a = asarray([1])
294282
b = asarray(a, copy=True)
295283
assert is_lib_func(b)
296284
a[0] = 0
297-
assert all(b[0] == 1)
298-
assert all(a[0] == 0)
285+
assert b[0] == 1
286+
assert a[0] == 0
299287

300288
a = asarray([1])
301-
if supports_copy_false_same_ns:
302-
b = asarray(a, copy=False)
303-
assert is_lib_func(b)
304-
a[0] = 0
305-
assert all(b[0] == 0)
306-
else:
307-
pytest.raises(NotImplementedError, lambda: asarray(a, copy=False))
308289

309-
a = asarray([1])
310-
if supports_copy_false_same_ns:
311-
pytest.raises(ValueError, lambda: asarray(a, copy=False,
312-
dtype=xp.float64))
313-
else:
314-
pytest.raises(NotImplementedError, lambda: asarray(a, copy=False, dtype=xp.float64))
290+
# Test copy=False within the same namespace
291+
b = asarray(a, copy=False)
292+
assert is_lib_func(b)
293+
a[0] = 0
294+
assert b[0] == 0
295+
with pytest.raises(ValueError):
296+
asarray(a, copy=False, dtype=xp.float64)
315297

298+
# copy=None defaults to False when possible
316299
a = asarray([1])
317300
b = asarray(a, copy=None)
318301
assert is_lib_func(b)
319302
a[0] = 0
320-
assert all(b[0] == 0)
303+
assert b[0] == 0
321304

305+
# copy=None defaults to True when impossible
322306
a = asarray([1.0], dtype=xp.float32)
323307
assert a.dtype == xp.float32
324308
b = asarray(a, dtype=xp.float64, copy=None)
325309
assert is_lib_func(b)
326310
assert b.dtype == xp.float64
327311
a[0] = 0.0
328-
assert all(b[0] == 1.0)
312+
assert b[0] == 1.0
329313

314+
# copy=None defaults to False when possible
330315
a = asarray([1.0], dtype=xp.float64)
331316
assert a.dtype == xp.float64
332317
b = asarray(a, dtype=xp.float64, copy=None)
333318
assert is_lib_func(b)
334319
assert b.dtype == xp.float64
335320
a[0] = 0.0
336-
assert all(b[0] == 0.0)
321+
assert b[0] == 0.0
337322

338323
# Python built-in types
339324
for obj in [True, 0, 0.0, 0j, [0], [[0]]]:
340-
asarray(obj, copy=True) # No error
341-
asarray(obj, copy=None) # No error
342-
if supports_copy_false_other_ns:
343-
pytest.raises(ValueError, lambda: asarray(obj, copy=False))
344-
else:
345-
pytest.raises(NotImplementedError, lambda: asarray(obj, copy=False))
325+
asarray(obj, copy=True) # No error
326+
asarray(obj, copy=None) # No error
327+
328+
with pytest.raises(ValueError):
329+
asarray(obj, copy=False)
346330

347331
# Use the standard library array to test the buffer protocol
348-
a = array.array('f', [1.0])
332+
a = array.array("f", [1.0])
349333
b = asarray(a, copy=True)
350334
assert is_lib_func(b)
351335
a[0] = 0.0
352-
assert all(b[0] == 1.0)
336+
assert b[0] == 1.0
353337

354-
a = array.array('f', [1.0])
355-
if supports_copy_false_other_ns:
338+
a = array.array("f", [1.0])
339+
if library in ("cupy", "dask.array"):
340+
with pytest.raises(ValueError):
341+
asarray(a, copy=False)
342+
else:
356343
b = asarray(a, copy=False)
357344
assert is_lib_func(b)
358345
a[0] = 0.0
359-
assert all(b[0] == 0.0)
360-
else:
361-
pytest.raises(NotImplementedError, lambda: asarray(a, copy=False))
346+
assert b[0] == 0.0
362347

363-
a = array.array('f', [1.0])
348+
a = array.array("f", [1.0])
364349
b = asarray(a, copy=None)
365350
assert is_lib_func(b)
366351
a[0] = 0.0
367-
if library in ('cupy', 'dask.array'):
352+
if library in ("cupy", "dask.array"):
368353
# A copy is required for libraries where the default device is not CPU
369354
# dask changed behaviour of copy=None in 2024.12 to copy;
370355
# this wrapper ensures the same behaviour in older versions too.
371356
# https://github.com/dask/dask/pull/11524/
372-
assert all(b[0] == 1.0)
357+
assert b[0] == 1.0
373358
else:
374-
assert all(b[0] == 0.0)
359+
# copy=None defaults to False when possible
360+
assert b[0] == 0.0
375361

376362

377363
@pytest.mark.parametrize("library", ["numpy", "cupy", "torch"])

0 commit comments

Comments
 (0)