Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MAINT: validate device on numpy and dask #301

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 21 additions & 3 deletions array_api_compat/common/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -595,11 +595,29 @@ def your_function(x, y):
# backwards compatibility alias
get_namespace = array_namespace

def _check_device(xp, device):
if xp == sys.modules.get('numpy'):
if device not in ["cpu", None]:

def _check_device(bare_xp, device):
"""
Validate dummy device on device-less array backends.

Notes
-----
This function is also invoked by CuPy, which does have multiple devices
if there are multiple GPUs available.
However, CuPy multi-device support is currently impossible
without using the global device or a context manager:

https://github.com/data-apis/array-api-compat/pull/293
"""
if bare_xp is sys.modules.get('numpy'):
if device not in ("cpu", None):
raise ValueError(f"Unsupported device for NumPy: {device!r}")

elif bare_xp is sys.modules.get('dask.array'):
if device not in ("cpu", _DASK_DEVICE, None):
raise ValueError(f"Unsupported device for Dask: {device!r}")


# Placeholder object to represent the dask device
# when the array backend is not the CPU.
# (since it is not easy to tell which device a dask array is on)
Expand Down
5 changes: 4 additions & 1 deletion array_api_compat/dask/array/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
)
import dask.array as da

from ...common import _aliases, array_namespace
from ...common import _aliases, _helpers, array_namespace
from ...common._typing import (
Array,
Device,
Expand Down Expand Up @@ -56,6 +56,7 @@ def astype(
specification for more details.
"""
# TODO: respect device keyword?
_helpers._check_device(da, device)

if not copy and dtype == x.dtype:
return x
Expand Down Expand Up @@ -86,6 +87,7 @@ def arange(
specification for more details.
"""
# TODO: respect device keyword?
_helpers._check_device(da, device)

args = [start]
if stop is not None:
Expand Down Expand Up @@ -155,6 +157,7 @@ def asarray(
specification for more details.
"""
# TODO: respect device keyword?
_helpers._check_device(da, device)

if isinstance(obj, da.Array):
if dtype is not None and dtype != obj.dtype:
Expand Down
6 changes: 3 additions & 3 deletions array_api_compat/numpy/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Optional, Union

from .._internal import get_xp
from ..common import _aliases
from ..common import _aliases, _helpers
from ..common._typing import NestedSequence, SupportsBufferProtocol
from ._info import __array_namespace_info__
from ._typing import Array, Device, DType
Expand Down Expand Up @@ -95,8 +95,7 @@ def asarray(
See the corresponding documentation in the array library and/or the array API
specification for more details.
"""
if device not in ["cpu", None]:
raise ValueError(f"Unsupported device for NumPy: {device!r}")
_helpers._check_device(np, device)

if hasattr(np, '_CopyMode'):
if copy is None:
Expand All @@ -122,6 +121,7 @@ def astype(
copy: bool = True,
device: Optional[Device] = None,
) -> Array:
_helpers._check_device(np, device)
return x.astype(dtype=dtype, copy=copy)


Expand Down
Loading