Skip to content

Commit fef8607

Browse files
committed
BUG: astype(..., copy=True) doesn't copy on dask
1 parent e5dd419 commit fef8607

File tree

2 files changed

+21
-1
lines changed

2 files changed

+21
-1
lines changed

array_api_compat/dask/array/_aliases.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,14 @@
3939

4040
isdtype = get_xp(np)(_aliases.isdtype)
4141
unstack = get_xp(da)(_aliases.unstack)
42-
astype = _aliases.astype
42+
43+
def astype(x: Array, dtype: Dtype, /, *, copy: bool = True) -> Array:
44+
if not copy and dtype == x.dtype:
45+
return x
46+
# dask astype doesn't respect copy=True so copy
47+
# manually via numpy
48+
x = np.array(x, dtype=dtype, copy=copy)
49+
return da.from_array(x)
4350

4451
# Common aliases
4552

tests/test_common.py

+13
Original file line numberDiff line numberDiff line change
@@ -272,3 +272,16 @@ def test_asarray_copy(library):
272272
assert all(b[0] == 1.0)
273273
else:
274274
assert all(b[0] == 0.0)
275+
276+
@pytest.mark.parametrize("library", wrapped_libraries)
277+
def test_astype_copy(library):
278+
# array-api-tests currently doesn't check copy=True
279+
# makes a copy when dtypes are the same
280+
# so we check that here
281+
xp = import_(library, wrapper=True)
282+
a = xp.asarray([1])
283+
b = xp.astype(a, a.dtype, copy=True)
284+
285+
a[0] = 10
286+
287+
assert b[0] == 1

0 commit comments

Comments
 (0)