Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support changed asarray behaviour in dask 2023.12.0 #214

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 22 additions & 13 deletions array_api_compat/dask/array/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,23 +128,32 @@ def asarray(

See the corresponding documentation in the array library and/or the array API
specification for more details.

.. note::
copy=True means that if you update the output array the input will never
be affected; however the output array may internally hold references to the
input array, preventing deallocation. This kind of implementation detail should
be left at dask's discretion.
"""
if copy is False:
# copy=False is not yet implemented in dask
raise NotImplementedError("copy=False is not yet implemented")
elif copy is True:
if isinstance(obj, da.Array) and dtype is None:
return obj.copy()
# Go through numpy, since dask copy is no-op by default
obj = np.array(obj, dtype=dtype, copy=True)
return da.array(obj, dtype=dtype)
else:
if not isinstance(obj, da.Array) or dtype is not None and obj.dtype != dtype:
obj = np.asarray(obj, dtype=dtype)
return da.from_array(obj)
return obj
raise NotImplementedError("copy=False can't be implemented in dask")

if (
copy is True
and isinstance(obj, da.Array)
and (dtype is None or dtype == obj.dtype)
):
return obj.copy()

obj = da.asarray(obj, dtype=dtype)

# Backport https://github.com/dask/dask/pull/11586
if dtype not in (None, obj.dtype):
obj = obj.astype(dtype)

return obj

return da.asarray(obj, dtype=dtype, **kwargs)

from dask.array import (
# Element wise aliases
Expand Down
49 changes: 34 additions & 15 deletions tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,10 @@ def test_to_device_host(library):
@pytest.mark.parametrize("target_library", is_array_functions.keys())
@pytest.mark.parametrize("source_library", is_array_functions.keys())
def test_asarray_cross_library(source_library, target_library, request):
if source_library == "dask.array" and target_library == "torch":
if (
(source_library == "dask.array" and target_library == "torch")
or (source_library == "torch" and target_library == "dask.array")
):
# Allow rest of test to execute instead of immediately xfailing
# xref https://github.com/pandas-dev/pandas/issues/38902

Expand All @@ -112,6 +115,7 @@ def test_asarray_cross_library(source_library, target_library, request):

assert is_tgt_type(b), f"Expected {b} to be a {tgt_lib.ndarray}, but was {type(b)}"


@pytest.mark.parametrize("library", wrapped_libraries)
def test_asarray_copy(library):
# Note, we have this test here because the test suite currently doesn't
Expand All @@ -130,41 +134,57 @@ def test_asarray_copy(library):
else:
supports_copy_false = True

# Tests for copy=True
a = asarray([1])
b = asarray(a, copy=True)
assert is_lib_func(b)
a[0] = 0
assert all(b[0] == 1)
assert all(a[0] == 0)

a = asarray([1])
b = asarray(a, copy=True, dtype=a.dtype)
assert is_lib_func(b)
a[0] = 0
assert all(b[0] == 1)
assert all(a[0] == 0)

# Tests for copy=False
a = asarray([1])
if supports_copy_false:
b = asarray(a, copy=False)
assert is_lib_func(b)
a[0] = 0
assert all(b[0] == 0)
else:
pytest.raises(NotImplementedError, lambda: asarray(a, copy=False))
with pytest.raises(NotImplementedError):
asarray(a, copy=False)

a = asarray([1])
if supports_copy_false:
pytest.raises(ValueError, lambda: asarray(a, copy=False,
dtype=xp.float64))
with pytest.raises(ValueError):
asarray(a, copy=False, dtype=xp.float64)
else:
pytest.raises(NotImplementedError, lambda: asarray(a, copy=False, dtype=xp.float64))
with pytest.raises(NotImplementedError):
asarray(a, copy=False, dtype=xp.float64)

# Tests for copy=None
# Do not test whether the buffer is shared or not after copy=None.
# A library should have the freedom to alter its behaviour
# without treating it as a breaking change.
a = asarray([1])
b = asarray(a, copy=None)
assert is_lib_func(b)
a[0] = 0
assert all(b[0] == 0)
assert all((b[0] == 1.0) | (b[0] == 0.0))

a = asarray([1.0], dtype=xp.float32)
assert a.dtype == xp.float32
b = asarray(a, dtype=xp.float64, copy=None)
assert is_lib_func(b)
assert b.dtype == xp.float64
a[0] = 0.0
# dtype change must always trigger a copy
assert all(b[0] == 1.0)

a = asarray([1.0], dtype=xp.float64)
Expand All @@ -173,16 +193,18 @@ def test_asarray_copy(library):
assert is_lib_func(b)
assert b.dtype == xp.float64
a[0] = 0.0
assert all(b[0] == 0.0)
assert all((b[0] == 1.0) | (b[0] == 0.0))

# Python built-in types
for obj in [True, 0, 0.0, 0j, [0], [[0]]]:
asarray(obj, copy=True) # No error
asarray(obj, copy=None) # No error
if supports_copy_false:
pytest.raises(ValueError, lambda: asarray(obj, copy=False))
with pytest.raises(ValueError):
asarray(obj, copy=False)
else:
pytest.raises(NotImplementedError, lambda: asarray(obj, copy=False))
with pytest.raises(NotImplementedError):
asarray(obj, copy=False)

# Use the standard library array to test the buffer protocol
a = array.array('f', [1.0])
Expand All @@ -198,14 +220,11 @@ def test_asarray_copy(library):
a[0] = 0.0
assert all(b[0] == 0.0)
else:
pytest.raises(NotImplementedError, lambda: asarray(a, copy=False))
with pytest.raises(NotImplementedError):
asarray(a, copy=False)

a = array.array('f', [1.0])
b = asarray(a, copy=None)
assert is_lib_func(b)
a[0] = 0.0
if library == 'cupy':
# A copy is required for libraries where the default device is not CPU
assert all(b[0] == 1.0)
else:
assert all(b[0] == 0.0)
assert all((b[0] == 1.0) | (b[0] == 0.0))
Loading