diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index 67c619b8..ef5f755d 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -352,7 +352,9 @@ def is_dask_namespace(xp: Namespace) -> bool: """ Returns True if `xp` is a Dask namespace. - This includes both ``dask.array`` itself and the version wrapped by array-api-compat. + This includes ``dask.array`` itself, the version wrapped by array-api-compat, + and the bespoke namespaces generated by + ``array_api_compat.dask.array.wrap_namespace``. See Also -------- @@ -366,7 +368,11 @@ def is_dask_namespace(xp: Namespace) -> bool: is_pydata_sparse_namespace is_array_api_strict_namespace """ - return xp.__name__ in {'dask.array', _compat_module_name() + '.dask.array'} + da_compat_name = _compat_module_name() + '.dask.array' + return ( + xp.__name__ in {'dask.array', da_compat_name} + or xp.__name__.startswith(da_compat_name + '.') + ) def is_jax_namespace(xp: Namespace) -> bool: @@ -543,8 +549,16 @@ def your_function(x, y): elif is_dask_array(x): if _use_compat: _check_api_version(api_version) - from ..dask import array as dask_namespace - namespaces.add(dask_namespace) + from ..dask.array import wrap_namespace + + # The meta-namespace is only used to generate the meta-array, so it + # would be useless to create a namespace such as e.g. + # array_api_compat.dask.array.array_api_compat.cupy. + # It would get worse once you vendor array-api-compat! + # So keep it clean with array_api_compat.dask.array.cupy. + mxp = array_namespace(x._meta, use_compat=False) + xp = wrap_namespace(mxp) + namespaces.add(xp) else: import dask.array as da namespaces.add(da) diff --git a/array_api_compat/dask/array/__init__.py b/array_api_compat/dask/array/__init__.py index bb649306..7d67c0d1 100644 --- a/array_api_compat/dask/array/__init__.py +++ b/array_api_compat/dask/array/__init__.py @@ -2,6 +2,7 @@ # These imports may overwrite names from the import * above. from ._aliases import * # noqa: F403 +from ._meta import wrap_namespace # noqa: F401 __array_api_version__ = '2024.12' diff --git a/array_api_compat/dask/array/_aliases.py b/array_api_compat/dask/array/_aliases.py index 4733b1a6..eaba0141 100644 --- a/array_api_compat/dask/array/_aliases.py +++ b/array_api_compat/dask/array/_aliases.py @@ -148,6 +148,7 @@ def asarray( dtype: Optional[DType] = None, device: Optional[Device] = None, copy: Optional[Union[bool, np._CopyMode]] = None, + like: Optional[Array] = None, **kwargs, ) -> Array: """ @@ -164,7 +165,11 @@ def asarray( if copy is False: raise ValueError("Unable to avoid copy when changing dtype") obj = obj.astype(dtype) - return obj.copy() if copy else obj + if copy: + obj = obj.copy() + if like is not None: + obj = da.asarray(obj, like=like) + return obj if copy is False: raise NotImplementedError( @@ -173,7 +178,11 @@ def asarray( # copy=None to be uniform across dask < 2024.12 and >= 2024.12 # see https://github.com/dask/dask/pull/11524/ - obj = np.array(obj, dtype=dtype, copy=True) + if like is not None: + mxp = array_namespace(like) + obj = mxp.asarray(obj, dtype=dtype, copy=True) + else: + obj = np.array(obj, dtype=dtype, copy=True) return da.from_array(obj) diff --git a/array_api_compat/dask/array/_meta.py b/array_api_compat/dask/array/_meta.py new file mode 100644 index 00000000..8bf20205 --- /dev/null +++ b/array_api_compat/dask/array/_meta.py @@ -0,0 +1,50 @@ +import functools +import sys +import types + +from ...common._helpers import is_numpy_namespace +from ...common._typing import Namespace + +__all__ = ['wrap_namespace'] +_all_ignore = ['functools', 'sys', 'types', 'is_numpy_namespace'] + + +def wrap_namespace(xp: Namespace) -> Namespace: + """Create a bespoke Dask namespace that wraps around another namespace. + + Parameters + ---------- + xp : namespace + Namespace to be wrapped by Dask + + Returns + ------- + namespace : + A module object that duplicates array_api_compat.dask.array, with the + difference that all creation functions will create an array with the same + meta namespace as the input. + """ + from .. import array as da_compat + + if is_numpy_namespace(xp): + return da_compat + + mod_name = f'{da_compat.__name__}.{xp.__name__}' + try: + return sys.modules[mod_name] + except KeyError: + pass + + mod = types.ModuleType(mod_name) + sys.modules[mod_name] = mod + + meta = xp.empty(()) + for name, v in da_compat.__dict__.items(): + if name.startswith('_'): + continue + if name in {'arange', 'asarray', 'empty', 'eye', 'from_dlpack', + 'full', 'linspace', 'ones', 'zeros'}: + v = functools.wraps(v)(functools.partial(v, like=meta)) + setattr(mod, name, v) + + return mod