-
Notifications
You must be signed in to change notification settings - Fork 33
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix dask #89
Fix dask #89
Changes from all commits
35e556c
9afed29
327a1e2
6bcc4a9
e9e740f
a4f1b2c
3f06837
924d297
2e4c796
72919ed
23eb764
1f7e47c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -127,3 +127,6 @@ dmypy.json | |
|
||
# Pyre type checker | ||
.pyre/ | ||
|
||
# macOS specific iles | ||
.DS_Store |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -159,7 +159,16 @@ def _check_device(xp, device): | |
if device not in ["cpu", None]: | ||
raise ValueError(f"Unsupported device for NumPy: {device!r}") | ||
|
||
# device() is not on numpy.ndarray and to_device() is not on numpy.ndarray | ||
# Placeholder object to represent the dask device | ||
# when the array backend is not the CPU. | ||
# (since it is not easy to tell which device a dask array is on) | ||
class _dask_device: | ||
def __repr__(self): | ||
return "DASK_DEVICE" | ||
|
||
_DASK_DEVICE = _dask_device() | ||
|
||
# device() is not on numpy.ndarray or dask.array and to_device() is not on numpy.ndarray | ||
# or cupy.ndarray. They are not included in array objects of this library | ||
# because this library just reuses the respective ndarray classes without | ||
# wrapping or subclassing them. These helper functions can be used instead of | ||
|
@@ -181,7 +190,17 @@ def device(x: Array, /) -> Device: | |
""" | ||
if is_numpy_array(x): | ||
return "cpu" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As I noted on the other PR, it would probably be better to use some kind of basic DaskDevice object here instead of the string "cpu", given that CPU isn't necessarily an accurate description of the device dask is running on. See https://github.com/data-apis/array-api-strict/blob/main/array_api_strict/_array_object.py#L43-L49 for example. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I return cpu now only if the type of the array backing the dask array is a ndarray. The rest of the time, I return a DaskDevice. Is this something close to what you wanted? (We might be able to do this for cupy, but it's tricky for e.g. multigpu cases I guess) |
||
if is_jax_array(x): | ||
elif is_dask_array(x): | ||
# Peek at the metadata of the jax array to determine type | ||
try: | ||
import numpy as np | ||
if isinstance(x._meta, np.ndarray): | ||
# Must be on CPU since backed by numpy | ||
return "cpu" | ||
except ImportError: | ||
pass | ||
return _DASK_DEVICE | ||
elif is_jax_array(x): | ||
# JAX has .device() as a method, but it is being deprecated so that it | ||
# can become a property, in accordance with the standard. In order for | ||
# this function to not break when JAX makes the flip, we check for | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Out of interest, is the plan to implement
array_api_compat.dask.{fft, linalg}
or wait for support fromdask
itself? A similar question w.r.t JAX.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I haven't attempted to wrap fft yet - waiting on #78 to do so.
Linalg can only be partially supported by us since there's missing methods in dask.