-
Notifications
You must be signed in to change notification settings - Fork 33
Fix dask #89
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix dask #89
Changes from 1 commit
35e556c
9afed29
327a1e2
6bcc4a9
e9e740f
a4f1b2c
3f06837
924d297
2e4c796
72919ed
23eb764
1f7e47c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -127,3 +127,6 @@ dmypy.json | |
|
||
# Pyre type checker | ||
.pyre/ | ||
|
||
# macOS specific iles | ||
.DS_Store |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -179,7 +179,9 @@ def device(x: Array, /) -> Device: | |
out: device | ||
a ``device`` object (see the "Device Support" section of the array API specification). | ||
""" | ||
if is_numpy_array(x): | ||
if is_numpy_array(x) or is_dask_array(x): | ||
# TODO: dask technically can support GPU arrays | ||
# Detecting the array backend isn't easy for dask, though, so just return CPU for now | ||
return "cpu" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As I noted on the other PR, it would probably be better to use some kind of basic DaskDevice object here instead of the string "cpu", given that CPU isn't necessarily an accurate description of the device dask is running on. See https://github.com/data-apis/array-api-strict/blob/main/array_api_strict/_array_object.py#L43-L49 for example. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I return cpu now only if the type of the array backing the dask array is a ndarray. The rest of the time, I return a DaskDevice. Is this something close to what you wanted? (We might be able to do this for cupy, but it's tricky for e.g. multigpu cases I guess) |
||
if is_jax_array(x): | ||
# JAX has .device() as a method, but it is being deprecated so that it | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Out of interest, is the plan to implement
array_api_compat.dask.{fft, linalg}
or wait for support fromdask
itself? A similar question w.r.t JAX.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I haven't attempted to wrap fft yet - waiting on #78 to do so.
Linalg can only be partially supported by us since there's missing methods in dask.