Skip to content

Commit 5aa1333

Browse files
lithomas1crusaderky
andcommitted
address feedback
Co-Authored-By: Guido Imperiale <[email protected]>
1 parent fef8607 commit 5aa1333

File tree

2 files changed

+17
-18
lines changed

2 files changed

+17
-18
lines changed

array_api_compat/dask/array/_aliases.py

+17-5
Original file line numberDiff line numberDiff line change
@@ -40,13 +40,25 @@
4040
isdtype = get_xp(np)(_aliases.isdtype)
4141
unstack = get_xp(da)(_aliases.unstack)
4242

43-
def astype(x: Array, dtype: Dtype, /, *, copy: bool = True) -> Array:
43+
def astype(
44+
x: Array,
45+
dtype: Dtype,
46+
/,
47+
*,
48+
copy: bool = True,
49+
device: Device | None = None
50+
) -> Array:
51+
if device is not None:
52+
raise NotImplementedError(
53+
"The device keyword is not implemented yet for "
54+
"array-api-compat wrapped dask"
55+
)
4456
if not copy and dtype == x.dtype:
4557
return x
46-
# dask astype doesn't respect copy=True so copy
47-
# manually via numpy
48-
x = np.array(x, dtype=dtype, copy=copy)
49-
return da.from_array(x)
58+
# dask astype doesn't respect copy=True,
59+
# so call copy manually afterwards
60+
x = x.astype(dtype)
61+
return x.copy() if copy else x
5062

5163
# Common aliases
5264

tests/test_common.py

-13
Original file line numberDiff line numberDiff line change
@@ -272,16 +272,3 @@ def test_asarray_copy(library):
272272
assert all(b[0] == 1.0)
273273
else:
274274
assert all(b[0] == 0.0)
275-
276-
@pytest.mark.parametrize("library", wrapped_libraries)
277-
def test_astype_copy(library):
278-
# array-api-tests currently doesn't check copy=True
279-
# makes a copy when dtypes are the same
280-
# so we check that here
281-
xp = import_(library, wrapper=True)
282-
a = xp.asarray([1])
283-
b = xp.astype(a, a.dtype, copy=True)
284-
285-
a[0] = 10
286-
287-
assert b[0] == 1

0 commit comments

Comments
 (0)