Skip to content

Commit 57d168e

Browse files
committed
[WIP] ENH: dask+cupy, dask+sparse etc. namespaces
1 parent b5a57eb commit 57d168e

File tree

4 files changed

+79
-6
lines changed

4 files changed

+79
-6
lines changed

array_api_compat/common/_helpers.py

+18-4
Original file line numberDiff line numberDiff line change
@@ -352,7 +352,9 @@ def is_dask_namespace(xp: Namespace) -> bool:
352352
"""
353353
Returns True if `xp` is a Dask namespace.
354354
355-
This includes both ``dask.array`` itself and the version wrapped by array-api-compat.
355+
This includes ``dask.array`` itself, the version wrapped by array-api-compat,
356+
and the bespoke namespaces generated by
357+
``array_api_compat.dask.array.wrap_namespace``.
356358
357359
See Also
358360
--------
@@ -366,7 +368,11 @@ def is_dask_namespace(xp: Namespace) -> bool:
366368
is_pydata_sparse_namespace
367369
is_array_api_strict_namespace
368370
"""
369-
return xp.__name__ in {'dask.array', _compat_module_name() + '.dask.array'}
371+
da_compat_name = _compat_module_name() + '.dask.array'
372+
return (
373+
xp.__name__ in {'dask.array', da_compat_name}
374+
or xp.__name__.startswith(da_compat_name + '.')
375+
)
370376

371377

372378
def is_jax_namespace(xp: Namespace) -> bool:
@@ -543,8 +549,16 @@ def your_function(x, y):
543549
elif is_dask_array(x):
544550
if _use_compat:
545551
_check_api_version(api_version)
546-
from ..dask import array as dask_namespace
547-
namespaces.add(dask_namespace)
552+
from ..dask.array import wrap_namespace
553+
554+
# The meta-namespace is only used to generate the meta-array, so it
555+
# would be useless to create a namespace such as e.g.
556+
# array_api_compat.dask.array.array_api_compat.cupy.
557+
# It would get worse once you vendor array-api-compat!
558+
# So keep it clean with array_api_compat.dask.array.cupy.
559+
mxp = array_namespace(x._meta, use_compat=False)
560+
xp = wrap_namespace(mxp)
561+
namespaces.add(xp)
548562
else:
549563
import dask.array as da
550564
namespaces.add(da)

array_api_compat/dask/array/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
# These imports may overwrite names from the import * above.
44
from ._aliases import * # noqa: F403
5+
from ._meta import wrap_namespace # noqa: F401
56

67
__array_api_version__ = '2024.12'
78

array_api_compat/dask/array/_aliases.py

+11-2
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,7 @@ def asarray(
143143
dtype: Optional[DType] = None,
144144
device: Optional[Device] = None,
145145
copy: Optional[Union[bool, np._CopyMode]] = None,
146+
like: Optional[Array] = None,
146147
**kwargs,
147148
) -> Array:
148149
"""
@@ -158,7 +159,11 @@ def asarray(
158159
if copy is False:
159160
raise ValueError("Unable to avoid copy when changing dtype")
160161
obj = obj.astype(dtype)
161-
return obj.copy() if copy else obj
162+
if copy:
163+
obj = obj.copy()
164+
if like is not None:
165+
obj = da.asarray(obj, like=like)
166+
return obj
162167

163168
if copy is False:
164169
raise NotImplementedError(
@@ -167,7 +172,11 @@ def asarray(
167172

168173
# copy=None to be uniform across dask < 2024.12 and >= 2024.12
169174
# see https://github.com/dask/dask/pull/11524/
170-
obj = np.array(obj, dtype=dtype, copy=True)
175+
if like is not None:
176+
mxp = array_namespace(like)
177+
obj = mxp.asarray(obj, dtype=dtype, copy=True)
178+
else:
179+
obj = np.array(obj, dtype=dtype, copy=True)
171180
return da.from_array(obj)
172181

173182

array_api_compat/dask/array/_meta.py

+49
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import functools
2+
import sys
3+
import types
4+
5+
from ...common import is_numpy_namespace
6+
7+
__all__ = ['wrap_namespace']
8+
_all_ignore = ['functools', 'sys', 'types', 'is_numpy_namespace']
9+
10+
11+
def wrap_namespace(xp):
12+
"""Create a bespoke Dask namespace that wraps around another namespace.
13+
14+
Parameters
15+
----------
16+
xp : namespace
17+
Namespace to be wrapped by Dask
18+
19+
Returns
20+
-------
21+
namespace :
22+
A module object that duplicates array_api_compat.dask.array, with the
23+
difference that all creation functions will create an array with the same
24+
meta namespace as the input.
25+
"""
26+
from .. import array as da_compat
27+
28+
if is_numpy_namespace(xp):
29+
return da_compat
30+
31+
mod_name = f'{da_compat.__name__}.{xp.__name__}'
32+
try:
33+
return sys.modules[mod_name]
34+
except KeyError:
35+
pass
36+
37+
mod = types.ModuleType(mod_name)
38+
sys.modules[mod_name] = mod
39+
40+
meta = xp.empty(())
41+
for name, v in da_compat.__dict__.items():
42+
if name.startswith('_'):
43+
continue
44+
if name in {'arange', 'asarray', 'empty', 'eye', 'from_dlpack',
45+
'full', 'linspace', 'ones', 'zeros'}:
46+
v = functools.wraps(v)(functools.partial(v, like=meta))
47+
setattr(mod, name, v)
48+
49+
return mod

0 commit comments

Comments
 (0)