diff --git a/array_api_compat/dask/array/_aliases.py b/array_api_compat/dask/array/_aliases.py index ee2d88c0..4371f769 100644 --- a/array_api_compat/dask/array/_aliases.py +++ b/array_api_compat/dask/array/_aliases.py @@ -128,23 +128,32 @@ def asarray( See the corresponding documentation in the array library and/or the array API specification for more details. + + .. note:: + copy=True means that if you update the output array the input will never + be affected; however the output array may internally hold references to the + input array, preventing deallocation. This kind of implementation detail should + be left at dask's discretion. """ if copy is False: # copy=False is not yet implemented in dask - raise NotImplementedError("copy=False is not yet implemented") - elif copy is True: - if isinstance(obj, da.Array) and dtype is None: - return obj.copy() - # Go through numpy, since dask copy is no-op by default - obj = np.array(obj, dtype=dtype, copy=True) - return da.array(obj, dtype=dtype) - else: - if not isinstance(obj, da.Array) or dtype is not None and obj.dtype != dtype: - obj = np.asarray(obj, dtype=dtype) - return da.from_array(obj) - return obj + raise NotImplementedError("copy=False can't be implemented in dask") + + if ( + copy is True + and isinstance(obj, da.Array) + and (dtype is None or dtype == obj.dtype) + ): + return obj.copy() + + obj = da.asarray(obj, dtype=dtype) + + # Backport https://github.com/dask/dask/pull/11586 + if dtype not in (None, obj.dtype): + obj = obj.astype(dtype) + + return obj - return da.asarray(obj, dtype=dtype, **kwargs) from dask.array import ( # Element wise aliases diff --git a/tests/test_common.py b/tests/test_common.py index e1cfa9eb..199dbc3b 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -91,7 +91,10 @@ def test_to_device_host(library): @pytest.mark.parametrize("target_library", is_array_functions.keys()) @pytest.mark.parametrize("source_library", is_array_functions.keys()) def test_asarray_cross_library(source_library, target_library, request): - if source_library == "dask.array" and target_library == "torch": + if ( + (source_library == "dask.array" and target_library == "torch") + or (source_library == "torch" and target_library == "dask.array") + ): # Allow rest of test to execute instead of immediately xfailing # xref https://github.com/pandas-dev/pandas/issues/38902 @@ -112,6 +115,7 @@ def test_asarray_cross_library(source_library, target_library, request): assert is_tgt_type(b), f"Expected {b} to be a {tgt_lib.ndarray}, but was {type(b)}" + @pytest.mark.parametrize("library", wrapped_libraries) def test_asarray_copy(library): # Note, we have this test here because the test suite currently doesn't @@ -130,6 +134,7 @@ def test_asarray_copy(library): else: supports_copy_false = True + # Tests for copy=True a = asarray([1]) b = asarray(a, copy=True) assert is_lib_func(b) @@ -137,6 +142,14 @@ def test_asarray_copy(library): assert all(b[0] == 1) assert all(a[0] == 0) + a = asarray([1]) + b = asarray(a, copy=True, dtype=a.dtype) + assert is_lib_func(b) + a[0] = 0 + assert all(b[0] == 1) + assert all(a[0] == 0) + + # Tests for copy=False a = asarray([1]) if supports_copy_false: b = asarray(a, copy=False) @@ -144,20 +157,26 @@ def test_asarray_copy(library): a[0] = 0 assert all(b[0] == 0) else: - pytest.raises(NotImplementedError, lambda: asarray(a, copy=False)) + with pytest.raises(NotImplementedError): + asarray(a, copy=False) a = asarray([1]) if supports_copy_false: - pytest.raises(ValueError, lambda: asarray(a, copy=False, - dtype=xp.float64)) + with pytest.raises(ValueError): + asarray(a, copy=False, dtype=xp.float64) else: - pytest.raises(NotImplementedError, lambda: asarray(a, copy=False, dtype=xp.float64)) + with pytest.raises(NotImplementedError): + asarray(a, copy=False, dtype=xp.float64) + # Tests for copy=None + # Do not test whether the buffer is shared or not after copy=None. + # A library should have the freedom to alter its behaviour + # without treating it as a breaking change. a = asarray([1]) b = asarray(a, copy=None) assert is_lib_func(b) a[0] = 0 - assert all(b[0] == 0) + assert all((b[0] == 1.0) | (b[0] == 0.0)) a = asarray([1.0], dtype=xp.float32) assert a.dtype == xp.float32 @@ -165,6 +184,7 @@ def test_asarray_copy(library): assert is_lib_func(b) assert b.dtype == xp.float64 a[0] = 0.0 + # dtype change must always trigger a copy assert all(b[0] == 1.0) a = asarray([1.0], dtype=xp.float64) @@ -173,16 +193,18 @@ def test_asarray_copy(library): assert is_lib_func(b) assert b.dtype == xp.float64 a[0] = 0.0 - assert all(b[0] == 0.0) + assert all((b[0] == 1.0) | (b[0] == 0.0)) # Python built-in types for obj in [True, 0, 0.0, 0j, [0], [[0]]]: asarray(obj, copy=True) # No error asarray(obj, copy=None) # No error if supports_copy_false: - pytest.raises(ValueError, lambda: asarray(obj, copy=False)) + with pytest.raises(ValueError): + asarray(obj, copy=False) else: - pytest.raises(NotImplementedError, lambda: asarray(obj, copy=False)) + with pytest.raises(NotImplementedError): + asarray(obj, copy=False) # Use the standard library array to test the buffer protocol a = array.array('f', [1.0]) @@ -198,14 +220,11 @@ def test_asarray_copy(library): a[0] = 0.0 assert all(b[0] == 0.0) else: - pytest.raises(NotImplementedError, lambda: asarray(a, copy=False)) + with pytest.raises(NotImplementedError): + asarray(a, copy=False) a = array.array('f', [1.0]) b = asarray(a, copy=None) assert is_lib_func(b) a[0] = 0.0 - if library == 'cupy': - # A copy is required for libraries where the default device is not CPU - assert all(b[0] == 1.0) - else: - assert all(b[0] == 0.0) + assert all((b[0] == 1.0) | (b[0] == 0.0))