Skip to content

Commit 16f31f8

Browse files
committed
WIP lazy_apply_elementwise
1 parent bb6129b commit 16f31f8

File tree

4 files changed

+174
-23
lines changed

4 files changed

+174
-23
lines changed

docs/api-lazy.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ lazy backends, e.g. Dask or JAX:
1010
:toctree: generated
1111
1212
lazy_apply
13+
lazy_apply_elementwise
1314
testing.lazy_xp_function
1415
testing.patch_lazy_xp_functions
1516
```

src/array_api_extra/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
setdiff1d,
1515
sinc,
1616
)
17-
from ._lib._lazy import lazy_apply
17+
from ._lib._lazy import lazy_apply, lazy_apply_elementwise
1818

1919
__version__ = "0.7.2.dev0"
2020

@@ -31,6 +31,7 @@
3131
"isclose",
3232
"kron",
3333
"lazy_apply",
34+
"lazy_apply_elementwise",
3435
"nunique",
3536
"pad",
3637
"setdiff1d",

src/array_api_extra/_lib/_lazy.py

Lines changed: 160 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,11 @@
33
from __future__ import annotations
44

55
import math
6+
import operator
67
from collections.abc import Callable, Sequence
78
from functools import partial, wraps
89
from types import ModuleType
9-
from typing import TYPE_CHECKING, Any, ParamSpec, TypeAlias, cast, overload
10+
from typing import TYPE_CHECKING, Any, TypeAlias, cast, overload
1011

1112
from ._funcs import broadcast_shapes
1213
from ._utils import _compat
@@ -27,41 +28,39 @@
2728
# Sphinx hack
2829
NumPyObject = Any
2930

30-
P = ParamSpec("P")
31-
3231

3332
@overload
34-
def lazy_apply( # type: ignore[decorated-any, valid-type]
35-
func: Callable[P, Array | ArrayLike],
33+
def lazy_apply( # type: ignore[explicit-any,decorated-any]
34+
func: Callable[..., Array | ArrayLike],
3635
*args: Array | complex | None,
3736
shape: tuple[int | None, ...] | None = None,
3837
dtype: DType | None = None,
3938
as_numpy: bool = False,
4039
xp: ModuleType | None = None,
41-
**kwargs: P.kwargs, # pyright: ignore[reportGeneralTypeIssues]
40+
**kwargs: Any,
4241
) -> Array: ... # numpydoc ignore=GL08
4342

4443

4544
@overload
46-
def lazy_apply( # type: ignore[decorated-any, valid-type]
47-
func: Callable[P, Sequence[Array | ArrayLike]],
45+
def lazy_apply( # type: ignore[explicit-any,decorated-any]
46+
func: Callable[..., Sequence[Array | ArrayLike]],
4847
*args: Array | complex | None,
4948
shape: Sequence[tuple[int | None, ...]],
5049
dtype: Sequence[DType] | None = None,
5150
as_numpy: bool = False,
5251
xp: ModuleType | None = None,
53-
**kwargs: P.kwargs, # pyright: ignore[reportGeneralTypeIssues]
52+
**kwargs: Any,
5453
) -> tuple[Array, ...]: ... # numpydoc ignore=GL08
5554

5655

57-
def lazy_apply( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
58-
func: Callable[P, Array | ArrayLike | Sequence[Array | ArrayLike]],
56+
def lazy_apply( # type: ignore[explicit-any] # numpydoc ignore=GL07,SA04
57+
func: Callable[..., Array | ArrayLike | Sequence[Array | ArrayLike]],
5958
*args: Array | complex | None,
6059
shape: tuple[int | None, ...] | Sequence[tuple[int | None, ...]] | None = None,
6160
dtype: DType | Sequence[DType] | None = None,
6261
as_numpy: bool = False,
6362
xp: ModuleType | None = None,
64-
**kwargs: P.kwargs, # pyright: ignore[reportGeneralTypeIssues]
63+
**kwargs: Any,
6564
) -> Array | tuple[Array, ...]:
6665
"""
6766
Lazily apply an eager function.
@@ -162,10 +161,11 @@ def lazy_apply( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
162161
The outputs will also be returned as a single chunk and you should consider
163162
rechunking them into smaller chunks afterwards.
164163
165-
If you want to distribute the calculation across multiple workers, you
166-
should use :func:`dask.array.map_blocks`, :func:`dask.array.map_overlap`,
167-
:func:`dask.array.blockwise`, or a native Dask wrapper instead of
168-
`lazy_apply`.
164+
If you want to distribute the calculation across multiple workers and your
165+
function is elementwise, you should use :func:`lazy_apply_elementwise` instead.
166+
If the function is not elementwise, you should consider writing an ad-hoc
167+
variant for Dask using primitives like :func:`dask.array.blockwise`,
168+
:func:`dask.array.map_overlap`, or a native Dask algorithm.
169169
170170
Dask wrapping around other backends
171171
If ``as_numpy=False``, `func` will receive in input eager arrays of the meta
@@ -186,9 +186,9 @@ def lazy_apply( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
186186
187187
See Also
188188
--------
189+
lazy_apply_elementwise
189190
jax.transfer_guard
190191
jax.pure_callback
191-
dask.array.map_blocks
192192
dask.array.map_overlap
193193
dask.array.blockwise
194194
"""
@@ -240,7 +240,7 @@ def lazy_apply( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
240240
if is_dask_namespace(xp):
241241
import dask
242242

243-
metas: list[Array] = [arg._meta for arg in array_args] # pylint: disable=protected-access # pyright: ignore[reportAttributeAccessIssue]
243+
metas: list[Array] = [arg._meta for arg in array_args] # type: ignore[attr-defined] # pylint: disable=protected-access # pyright: ignore[reportAttributeAccessIssue]
244244
meta_xp = array_namespace(*metas)
245245

246246
wrapped = dask.delayed( # type: ignore[attr-defined] # pyright: ignore[reportPrivateImportUsage]
@@ -355,3 +355,145 @@ def wrapper( # type: ignore[decorated-any,explicit-any]
355355
return (xp.asarray(out, device=device),)
356356

357357
return wrapper
358+
359+
360+
@overload
361+
def lazy_apply_elementwise( # type: ignore[explicit-any,decorated-any]
362+
func: Callable[..., Array | ArrayLike],
363+
*args: Array | complex | None,
364+
dtype: DType | None = None,
365+
as_numpy: bool = False,
366+
xp: ModuleType | None = None,
367+
**kwargs: Any,
368+
) -> Array: ... # numpydoc ignore=GL08
369+
370+
371+
@overload
372+
def lazy_apply_elementwise( # type: ignore[explicit-any,decorated-any]
373+
func: Callable[..., Sequence[Array | ArrayLike]],
374+
*args: Array | complex | None,
375+
dtype: Sequence[DType | None],
376+
as_numpy: bool = False,
377+
xp: ModuleType | None = None,
378+
**kwargs: Any,
379+
) -> tuple[Array, ...]: ... # numpydoc ignore=GL08
380+
381+
382+
def lazy_apply_elementwise( # type: ignore[explicit-any]
383+
func: Callable[..., Array | ArrayLike | Sequence[Array | ArrayLike]],
384+
*args: Array | complex | None,
385+
dtype: DType | Sequence[DType | None] | None = None,
386+
as_numpy: bool = False,
387+
xp: ModuleType | None = None,
388+
**kwargs: Any,
389+
) -> Array | tuple[Array, ...]:
390+
"""
391+
Lazily apply an eager elementwise function.
392+
393+
This is a variant of :func:`lazy_apply` which expects `func` to be elementwise, e.g.
394+
each output point must depend exclusively from the corresponding input point in each
395+
inputarray. This can result in faster execution on some backends.
396+
397+
Parameters
398+
----------
399+
func : callable
400+
As in `lazy_apply`, but in addition it must be elementwise.
401+
*args : Array | int | float | complex | bool | None
402+
As in `lazy_apply`.
403+
dtype : DType | Sequence[DType | None], optional
404+
Output dtype or sequence of output dtypes, one for each output of `func`.
405+
dtype(s) must belong to the same array namespace as the input arrays.
406+
This also informs how many outputs the function has.
407+
Default: assume a single output and infer the result type(s) from
408+
the input arrays.
409+
as_numpy : bool, optional
410+
As in `lazy_apply`.
411+
xp : array_namespace, optional
412+
The standard-compatible namespace for `args`. Default: infer.
413+
**kwargs : Any, optional
414+
As in `lazy_apply`.
415+
416+
Returns
417+
-------
418+
Array | tuple[Array, ...]
419+
The result(s) of `func` applied to the input arrays, wrapped in the same
420+
array namespace as the inputs.
421+
If dtype is omitted or a single dtype, return a single array.
422+
Otherwise, return a tuple of arrays.
423+
424+
See Also
425+
--------
426+
lazy_apply : General version of this function.
427+
dask.array.map_blocks : Dask version of this function.
428+
429+
Notes
430+
-----
431+
Unlike in :func:`lazy_apply`, you can't define output shapes that aren't
432+
broadcasted from the input arrays.
433+
434+
Dask
435+
Unlike :func:`dask.array.map_blocks`, this function allows for multiple outputs.
436+
437+
Dask wrapping around other backends
438+
If ``as_numpy=False``, `func` will receive in input eager arrays of the meta
439+
namespace, as defined by the ``._meta`` attribute of the input Dask arrays. The
440+
outputs of `func` will be wrapped by the meta namespace, and then wrapped again
441+
by Dask.
442+
443+
All other backends
444+
This function is identical to :func:`lazy_apply`.
445+
"""
446+
args_not_none = [arg for arg in args if arg is not None]
447+
array_args = [arg for arg in args_not_none if not is_python_scalar(arg)]
448+
if not array_args:
449+
msg = "Must have at least one argument array"
450+
raise ValueError(msg)
451+
if xp is None:
452+
xp = array_namespace(*array_args)
453+
454+
# Normalize and validate dtype
455+
dtypes: list[DType]
456+
457+
if isinstance(dtype, Sequence):
458+
multi_output = True
459+
if None in dtype:
460+
rtype = xp.result_type(*args_not_none)
461+
dtypes = [d or rtype for d in dtype]
462+
else:
463+
dtypes = list(dtype) # pyright: ignore[reportUnknownArgumentType]
464+
else:
465+
multi_output = False
466+
dtypes = [dtype]
467+
del dtype
468+
469+
if not is_dask_namespace(xp):
470+
shape = broadcast_shapes(*(arg.shape for arg in array_args))
471+
return lazy_apply( # pyright: ignore[reportCallIssue]
472+
func, # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
473+
*args,
474+
shape=[shape] * len(dtypes) if multi_output else shape, # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
475+
dtype=dtypes if multi_output else dtypes[0],
476+
as_numpy=as_numpy,
477+
xp=xp,
478+
**kwargs,
479+
)
480+
481+
# Use da.map_blocks.
482+
# We need to handle multiple outputs, which map_blocks can't.
483+
484+
metas: list[Array] = [arg._meta for arg in array_args] # type: ignore[attr-defined] # pylint: disable=protected-access # pyright: ignore[reportAttributeAccessIssue]
485+
meta_xp = array_namespace(*metas)
486+
487+
wrapped = _lazy_apply_wrapper(func, as_numpy, multi_output, meta_xp)
488+
wrapped = partial(wrapped, **kwargs)
489+
490+
# Hack map_blocks to handle multiple outputs. This intermediate output has bugos
491+
# dtype and meta, but dask.array will never know as long as we always provide
492+
# explicit dtype and meta.
493+
temp = xp.map_blocks(wrapped, *args, dtype=dtypes[0], meta=metas[0])
494+
out = tuple(
495+
temp.map_blocks(operator.itemgetter(i), dtype=dtype, meta=metas[0])
496+
for i, dtype in enumerate(dtypes)
497+
)
498+
499+
return out if multi_output else out[0]

tests/test_lazy.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import pytest
77

88
import array_api_extra as xpx # Let some tests bypass lazy_xp_function
9-
from array_api_extra import lazy_apply
9+
from array_api_extra import lazy_apply, lazy_apply_elementwise
1010
from array_api_extra._lib._backends import Backend
1111
from array_api_extra._lib._testing import xp_assert_equal
1212
from array_api_extra._lib._utils import _compat
@@ -371,7 +371,7 @@ def eager(
371371
return x + 1
372372

373373
# Use explicit namespace to bypass monkey-patching by lazy_xp_function
374-
return xpx.lazy_apply( # pyright: ignore[reportCallIssue]
374+
return xpx.lazy_apply(
375375
eager,
376376
x,
377377
z={0: [1, 2]},
@@ -448,6 +448,13 @@ def f(x: Array) -> Array:
448448
with pytest.raises(ValueError, match="multiple shapes but only one dtype"):
449449
_ = lazy_apply(f, x, shape=[(1,), (2,)], dtype=np.int32) # type: ignore[call-overload] # pyright: ignore[reportCallIssue,reportArgumentType]
450450
with pytest.raises(ValueError, match="single shape but multiple dtypes"):
451-
_ = lazy_apply(f, x, shape=(1,), dtype=[np.int32, np.int64]) # pyright: ignore[reportCallIssue,reportArgumentType]
451+
_ = lazy_apply(f, x, shape=(1,), dtype=[np.int32, np.int64]) # type: ignore[call-overload] # pyright: ignore[reportCallIssue,reportArgumentType]
452452
with pytest.raises(ValueError, match="2 shapes and 1 dtypes"):
453-
_ = lazy_apply(f, x, shape=[(1,), (2,)], dtype=[np.int32]) # type: ignore[arg-type] # pyright: ignore[reportCallIssue,reportArgumentType]
453+
_ = lazy_apply(f, x, shape=[(1,), (2,)], dtype=[np.int32]) # type: ignore[call-overload] # pyright: ignore[reportCallIssue,reportArgumentType]
454+
455+
with pytest.raises(ValueError, match="at least one argument array"):
456+
_ = lazy_apply_elementwise(f, xp=np)
457+
with pytest.raises(ValueError, match="at least one argument array"):
458+
_ = lazy_apply_elementwise(f, 1, xp=np)
459+
with pytest.raises(ValueError, match="at least one argument array"):
460+
_ = lazy_apply_elementwise(f)

0 commit comments

Comments
 (0)