Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] ENH: dask+cupy, dask+sparse etc. namespaces #270

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 18 additions & 4 deletions array_api_compat/common/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,9 @@ def is_dask_namespace(xp: Namespace) -> bool:
"""
Returns True if `xp` is a Dask namespace.

This includes both ``dask.array`` itself and the version wrapped by array-api-compat.
This includes ``dask.array`` itself, the version wrapped by array-api-compat,
and the bespoke namespaces generated by
``array_api_compat.dask.array.wrap_namespace``.

See Also
--------
Expand All @@ -366,7 +368,11 @@ def is_dask_namespace(xp: Namespace) -> bool:
is_pydata_sparse_namespace
is_array_api_strict_namespace
"""
return xp.__name__ in {'dask.array', _compat_module_name() + '.dask.array'}
da_compat_name = _compat_module_name() + '.dask.array'
return (
xp.__name__ in {'dask.array', da_compat_name}
or xp.__name__.startswith(da_compat_name + '.')
)


def is_jax_namespace(xp: Namespace) -> bool:
Expand Down Expand Up @@ -543,8 +549,16 @@ def your_function(x, y):
elif is_dask_array(x):
if _use_compat:
_check_api_version(api_version)
from ..dask import array as dask_namespace
namespaces.add(dask_namespace)
from ..dask.array import wrap_namespace

# The meta-namespace is only used to generate the meta-array, so it
# would be useless to create a namespace such as e.g.
# array_api_compat.dask.array.array_api_compat.cupy.
# It would get worse once you vendor array-api-compat!
# So keep it clean with array_api_compat.dask.array.cupy.
mxp = array_namespace(x._meta, use_compat=False)
xp = wrap_namespace(mxp)
namespaces.add(xp)
else:
import dask.array as da
namespaces.add(da)
Expand Down
1 change: 1 addition & 0 deletions array_api_compat/dask/array/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

# These imports may overwrite names from the import * above.
from ._aliases import * # noqa: F403
from ._meta import wrap_namespace # noqa: F401

__array_api_version__ = '2024.12'

Expand Down
13 changes: 11 additions & 2 deletions array_api_compat/dask/array/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ def asarray(
dtype: Optional[DType] = None,
device: Optional[Device] = None,
copy: Optional[Union[bool, np._CopyMode]] = None,
like: Optional[Array] = None,
**kwargs,
) -> Array:
"""
Expand All @@ -161,7 +162,11 @@ def asarray(
if copy is False:
raise ValueError("Unable to avoid copy when changing dtype")
obj = obj.astype(dtype)
return obj.copy() if copy else obj
if copy:
obj = obj.copy()
if like is not None:
obj = da.asarray(obj, like=like)
return obj

if copy is False:
raise NotImplementedError(
Expand All @@ -170,7 +175,11 @@ def asarray(

# copy=None to be uniform across dask < 2024.12 and >= 2024.12
# see https://github.com/dask/dask/pull/11524/
obj = np.array(obj, dtype=dtype, copy=True)
if like is not None:
mxp = array_namespace(like)
obj = mxp.asarray(obj, dtype=dtype, copy=True)
else:
obj = np.array(obj, dtype=dtype, copy=True)
return da.from_array(obj)


Expand Down
50 changes: 50 additions & 0 deletions array_api_compat/dask/array/_meta.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import functools
import sys
import types

from ...common._helpers import is_numpy_namespace
from ...common._typing import Namespace

__all__ = ['wrap_namespace']
_all_ignore = ['functools', 'sys', 'types', 'is_numpy_namespace']


def wrap_namespace(xp: Namespace) -> Namespace:
"""Create a bespoke Dask namespace that wraps around another namespace.

Parameters
----------
xp : namespace
Namespace to be wrapped by Dask

Returns
-------
namespace :
A module object that duplicates array_api_compat.dask.array, with the
difference that all creation functions will create an array with the same
meta namespace as the input.
"""
from .. import array as da_compat

if is_numpy_namespace(xp):
return da_compat

mod_name = f'{da_compat.__name__}.{xp.__name__}'
try:
return sys.modules[mod_name]
except KeyError:
pass

mod = types.ModuleType(mod_name)
sys.modules[mod_name] = mod

meta = xp.empty(())
for name, v in da_compat.__dict__.items():
if name.startswith('_'):
continue
if name in {'arange', 'asarray', 'empty', 'eye', 'from_dlpack',
'full', 'linspace', 'ones', 'zeros'}:
v = functools.wraps(v)(functools.partial(v, like=meta))
setattr(mod, name, v)

return mod
Loading