From fef8607acdb4e8ba5b38ce7040aaf96dd219a345 Mon Sep 17 00:00:00 2001 From: Thomas Li <47963215+lithomas1@users.noreply.github.com> Date: Tue, 14 Jan 2025 11:03:04 -0800 Subject: [PATCH 1/3] BUG: astype(..., copy=True) doesn't copy on dask --- array_api_compat/dask/array/_aliases.py | 9 ++++++++- tests/test_common.py | 13 +++++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/array_api_compat/dask/array/_aliases.py b/array_api_compat/dask/array/_aliases.py index df8fede8..2cd792f9 100644 --- a/array_api_compat/dask/array/_aliases.py +++ b/array_api_compat/dask/array/_aliases.py @@ -39,7 +39,14 @@ isdtype = get_xp(np)(_aliases.isdtype) unstack = get_xp(da)(_aliases.unstack) -astype = _aliases.astype + +def astype(x: Array, dtype: Dtype, /, *, copy: bool = True) -> Array: + if not copy and dtype == x.dtype: + return x + # dask astype doesn't respect copy=True so copy + # manually via numpy + x = np.array(x, dtype=dtype, copy=copy) + return da.from_array(x) # Common aliases diff --git a/tests/test_common.py b/tests/test_common.py index 1a4a32dc..0bde5ebe 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -272,3 +272,16 @@ def test_asarray_copy(library): assert all(b[0] == 1.0) else: assert all(b[0] == 0.0) + +@pytest.mark.parametrize("library", wrapped_libraries) +def test_astype_copy(library): + # array-api-tests currently doesn't check copy=True + # makes a copy when dtypes are the same + # so we check that here + xp = import_(library, wrapper=True) + a = xp.asarray([1]) + b = xp.astype(a, a.dtype, copy=True) + + a[0] = 10 + + assert b[0] == 1 From 5aa13331584380a1fd5b65715cdeb4486d667311 Mon Sep 17 00:00:00 2001 From: Thomas Li <47963215+lithomas1@users.noreply.github.com> Date: Wed, 15 Jan 2025 13:44:03 -0800 Subject: [PATCH 2/3] address feedback Co-Authored-By: Guido Imperiale <6213168+crusaderky@users.noreply.github.com> --- array_api_compat/dask/array/_aliases.py | 22 +++++++++++++++++----- tests/test_common.py | 13 ------------- 2 files changed, 17 insertions(+), 18 deletions(-) diff --git a/array_api_compat/dask/array/_aliases.py b/array_api_compat/dask/array/_aliases.py index 2cd792f9..cd40d355 100644 --- a/array_api_compat/dask/array/_aliases.py +++ b/array_api_compat/dask/array/_aliases.py @@ -40,13 +40,25 @@ isdtype = get_xp(np)(_aliases.isdtype) unstack = get_xp(da)(_aliases.unstack) -def astype(x: Array, dtype: Dtype, /, *, copy: bool = True) -> Array: +def astype( + x: Array, + dtype: Dtype, + /, + *, + copy: bool = True, + device: Device | None = None +) -> Array: + if device is not None: + raise NotImplementedError( + "The device keyword is not implemented yet for " + "array-api-compat wrapped dask" + ) if not copy and dtype == x.dtype: return x - # dask astype doesn't respect copy=True so copy - # manually via numpy - x = np.array(x, dtype=dtype, copy=copy) - return da.from_array(x) + # dask astype doesn't respect copy=True, + # so call copy manually afterwards + x = x.astype(dtype) + return x.copy() if copy else x # Common aliases diff --git a/tests/test_common.py b/tests/test_common.py index 0bde5ebe..1a4a32dc 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -272,16 +272,3 @@ def test_asarray_copy(library): assert all(b[0] == 1.0) else: assert all(b[0] == 0.0) - -@pytest.mark.parametrize("library", wrapped_libraries) -def test_astype_copy(library): - # array-api-tests currently doesn't check copy=True - # makes a copy when dtypes are the same - # so we check that here - xp = import_(library, wrapper=True) - a = xp.asarray([1]) - b = xp.astype(a, a.dtype, copy=True) - - a[0] = 10 - - assert b[0] == 1 From 691c27bfd0c353204260c4f3a6f7052324a900a0 Mon Sep 17 00:00:00 2001 From: Thomas Li <47963215+lithomas1@users.noreply.github.com> Date: Wed, 15 Jan 2025 15:03:41 -0800 Subject: [PATCH 3/3] ignore device --- array_api_compat/dask/array/_aliases.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/array_api_compat/dask/array/_aliases.py b/array_api_compat/dask/array/_aliases.py index cd40d355..08514717 100644 --- a/array_api_compat/dask/array/_aliases.py +++ b/array_api_compat/dask/array/_aliases.py @@ -48,11 +48,7 @@ def astype( copy: bool = True, device: Device | None = None ) -> Array: - if device is not None: - raise NotImplementedError( - "The device keyword is not implemented yet for " - "array-api-compat wrapped dask" - ) + # TODO: respect device keyword? if not copy and dtype == x.dtype: return x # dask astype doesn't respect copy=True,