Skip to content

Commit 35e556c

Browse files
committed
Fix dask
1 parent 645f9a8 commit 35e556c

File tree

6 files changed

+36
-10
lines changed

6 files changed

+36
-10
lines changed

.gitignore

+3
Original file line numberDiff line numberDiff line change
@@ -127,3 +127,6 @@ dmypy.json
127127

128128
# Pyre type checker
129129
.pyre/
130+
131+
# macOS specific iles
132+
.DS_Store

README.md

+26-2
Original file line numberDiff line numberDiff line change
@@ -125,11 +125,11 @@ part of the specification but which are useful for using the array API:
125125
[`x.device`](https://data-apis.org/array-api/latest/API_specification/generated/signatures.array_object.array.device.html)
126126
in the array API specification. Included because `numpy.ndarray` does not
127127
include the `device` attribute and this library does not wrap or extend the
128-
array object. Note that for NumPy, `device(x)` is always `"cpu"`.
128+
array object. Note that for NumPy and dask, `device(x)` is always `"cpu"`.
129129

130130
- `to_device(x, device, /, *, stream=None)`: Equivalent to
131131
[`x.to_device`](https://data-apis.org/array-api/latest/API_specification/generated/signatures.array_object.array.to_device.html).
132-
Included because neither NumPy's, CuPy's, nor PyTorch's array objects
132+
Included because neither NumPy's, CuPy's, Dask's, nor PyTorch's array objects
133133
include this method. For NumPy, this function effectively does nothing since
134134
the only supported device is the CPU, but for CuPy, this method supports
135135
CuPy CUDA
@@ -240,6 +240,30 @@ Unlike the other libraries supported here, JAX array API support is contained
240240
entirely in the JAX library. The JAX array API support is tracked at
241241
https://github.com/google/jax/issues/18353.
242242

243+
## Dask
244+
245+
If you're using dask with numpy, many of the same limitations that apply to numpy
246+
will also apply to dask. Besides those differences, other limitations include missing
247+
sort functionality (no `sort` or `argsort`), and limited support for the optional `linalg`
248+
and `fft` extensions.
249+
250+
In particular, the `fft` namespace is not compliant with the array API spec. Any functions
251+
that you find under the `fft` namespace are the original, unwrapped functions under [`dask.array.fft`](https://docs.dask.org/en/latest/array-api.html#fast-fourier-transforms), which may or may not be Array API compliant. Use at your own risk!
252+
253+
For `linalg`, several methods are missing, for example:
254+
- `cross`
255+
- `det`
256+
- `eigh`
257+
- `eigvalsh`
258+
- `matrix_power`
259+
- `pinv`
260+
- `slogdet`
261+
- `matrix_norm`
262+
- `matrix_rank`
263+
Other methods may only be partially implemented or return incorrect results at times.
264+
265+
The minimum supported Dask version is 2023.12.0.
266+
243267
## Vendoring
244268

245269
This library supports vendoring as an installation method. To vendor the

array_api_compat/common/_helpers.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,9 @@ def device(x: Array, /) -> Device:
179179
out: device
180180
a ``device`` object (see the "Device Support" section of the array API specification).
181181
"""
182-
if is_numpy_array(x):
182+
if is_numpy_array(x) or is_dask_array(x):
183+
# TODO: dask technically can support GPU arrays
184+
# Detecting the array backend isn't easy for dask, though, so just return CPU for now
183185
return "cpu"
184186
if is_jax_array(x):
185187
# JAX has .device() as a method, but it is being deprecated so that it

array_api_compat/dask/array/__init__.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from dask.array import (
2323
arctanh as atanh,
2424
)
25-
from dask.array import (
25+
from numpy import (
2626
bool_ as bool,
2727
)
2828
from dask.array import (
@@ -67,15 +67,15 @@
6767
uint64,
6868
)
6969

70-
from ..common._helpers import (
70+
from ...common._helpers import (
7171
array_namespace,
7272
device,
7373
get_namespace,
7474
is_array_api_obj,
7575
size,
7676
to_device,
7777
)
78-
from ..internal import _get_all_public_members
78+
from ..._internal import _get_all_public_members
7979
from ._aliases import (
8080
UniqueAllResult,
8181
UniqueCountsResult,

array_api_compat/dask/array/linalg.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
)
88
from dask.array.linalg import * # noqa: F401, F403
99

10-
from .._internal import _get_all_public_members
10+
from ..._internal import _get_all_public_members
1111
from ._aliases import (
1212
EighResult,
1313
QRResult,

tests/test_common.py

-3
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,6 @@ def test_is_xp_array(library, func):
3131

3232
@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array", "jax.numpy"])
3333
def test_device(library):
34-
if library == "dask.array":
35-
pytest.xfail("device() needs to be fixed for dask")
36-
3734
xp = import_(library, wrapper=True)
3835

3936
# We can't test much for device() and to_device() other than that

0 commit comments

Comments
 (0)