3
3
from __future__ import annotations
4
4
5
5
import math
6
+ import operator
6
7
from collections .abc import Callable , Sequence
7
8
from functools import partial , wraps
8
9
from types import ModuleType
9
- from typing import TYPE_CHECKING , Any , ParamSpec , TypeAlias , cast , overload
10
+ from typing import TYPE_CHECKING , Any , TypeAlias , cast , overload
10
11
11
12
from ._funcs import broadcast_shapes
12
13
from ._utils import _compat
27
28
# Sphinx hack
28
29
NumPyObject = Any
29
30
30
- P = ParamSpec ("P" )
31
-
32
31
33
32
@overload
34
- def lazy_apply ( # type: ignore[decorated -any, valid-type ]
35
- func : Callable [P , Array | ArrayLike ],
33
+ def lazy_apply ( # type: ignore[explicit -any,decorated-any ]
34
+ func : Callable [... , Array | ArrayLike ],
36
35
* args : Array | complex | None ,
37
36
shape : tuple [int | None , ...] | None = None ,
38
37
dtype : DType | None = None ,
39
38
as_numpy : bool = False ,
40
39
xp : ModuleType | None = None ,
41
- ** kwargs : P . kwargs , # pyright: ignore[reportGeneralTypeIssues]
40
+ ** kwargs : Any ,
42
41
) -> Array : ... # numpydoc ignore=GL08
43
42
44
43
45
44
@overload
46
- def lazy_apply ( # type: ignore[decorated -any, valid-type ]
47
- func : Callable [P , Sequence [Array | ArrayLike ]],
45
+ def lazy_apply ( # type: ignore[explicit -any,decorated-any ]
46
+ func : Callable [... , Sequence [Array | ArrayLike ]],
48
47
* args : Array | complex | None ,
49
48
shape : Sequence [tuple [int | None , ...]],
50
49
dtype : Sequence [DType ] | None = None ,
51
50
as_numpy : bool = False ,
52
51
xp : ModuleType | None = None ,
53
- ** kwargs : P . kwargs , # pyright: ignore[reportGeneralTypeIssues]
52
+ ** kwargs : Any ,
54
53
) -> tuple [Array , ...]: ... # numpydoc ignore=GL08
55
54
56
55
57
- def lazy_apply ( # type: ignore[valid-type ] # numpydoc ignore=GL07,SA04
58
- func : Callable [P , Array | ArrayLike | Sequence [Array | ArrayLike ]],
56
+ def lazy_apply ( # type: ignore[explicit-any ] # numpydoc ignore=GL07,SA04
57
+ func : Callable [... , Array | ArrayLike | Sequence [Array | ArrayLike ]],
59
58
* args : Array | complex | None ,
60
59
shape : tuple [int | None , ...] | Sequence [tuple [int | None , ...]] | None = None ,
61
60
dtype : DType | Sequence [DType ] | None = None ,
62
61
as_numpy : bool = False ,
63
62
xp : ModuleType | None = None ,
64
- ** kwargs : P . kwargs , # pyright: ignore[reportGeneralTypeIssues]
63
+ ** kwargs : Any ,
65
64
) -> Array | tuple [Array , ...]:
66
65
"""
67
66
Lazily apply an eager function.
@@ -162,10 +161,11 @@ def lazy_apply( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
162
161
The outputs will also be returned as a single chunk and you should consider
163
162
rechunking them into smaller chunks afterwards.
164
163
165
- If you want to distribute the calculation across multiple workers, you
166
- should use :func:`dask.array.map_blocks`, :func:`dask.array.map_overlap`,
167
- :func:`dask.array.blockwise`, or a native Dask wrapper instead of
168
- `lazy_apply`.
164
+ If you want to distribute the calculation across multiple workers and your
165
+ function is elementwise, you should use :func:`lazy_apply_elementwise` instead.
166
+ If the function is not elementwise, you should consider writing an ad-hoc
167
+ variant for Dask using primitives like :func:`dask.array.blockwise`,
168
+ :func:`dask.array.map_overlap`, or a native Dask algorithm.
169
169
170
170
Dask wrapping around other backends
171
171
If ``as_numpy=False``, `func` will receive in input eager arrays of the meta
@@ -186,9 +186,9 @@ def lazy_apply( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
186
186
187
187
See Also
188
188
--------
189
+ lazy_apply_elementwise
189
190
jax.transfer_guard
190
191
jax.pure_callback
191
- dask.array.map_blocks
192
192
dask.array.map_overlap
193
193
dask.array.blockwise
194
194
"""
@@ -240,7 +240,7 @@ def lazy_apply( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
240
240
if is_dask_namespace (xp ):
241
241
import dask
242
242
243
- metas : list [Array ] = [arg ._meta for arg in array_args ] # pylint: disable=protected-access # pyright: ignore[reportAttributeAccessIssue]
243
+ metas : list [Array ] = [arg ._meta for arg in array_args ] # type: ignore[attr-defined] # pylint: disable=protected-access # pyright: ignore[reportAttributeAccessIssue]
244
244
meta_xp = array_namespace (* metas )
245
245
246
246
wrapped = dask .delayed ( # type: ignore[attr-defined] # pyright: ignore[reportPrivateImportUsage]
@@ -355,3 +355,145 @@ def wrapper( # type: ignore[decorated-any,explicit-any]
355
355
return (xp .asarray (out , device = device ),)
356
356
357
357
return wrapper
358
+
359
+
360
@overload
def lazy_apply_elementwise(  # type: ignore[explicit-any,decorated-any]
    func: Callable[..., Array | ArrayLike],
    *args: Array | complex | None,
    dtype: DType | None = None,
    as_numpy: bool = False,
    xp: ModuleType | None = None,
    **kwargs: Any,
) -> Array: ...  # numpydoc ignore=GL08


@overload
def lazy_apply_elementwise(  # type: ignore[explicit-any,decorated-any]
    func: Callable[..., Sequence[Array | ArrayLike]],
    *args: Array | complex | None,
    dtype: Sequence[DType | None],
    as_numpy: bool = False,
    xp: ModuleType | None = None,
    **kwargs: Any,
) -> tuple[Array, ...]: ...  # numpydoc ignore=GL08


def lazy_apply_elementwise(  # type: ignore[explicit-any]
    func: Callable[..., Array | ArrayLike | Sequence[Array | ArrayLike]],
    *args: Array | complex | None,
    dtype: DType | Sequence[DType | None] | None = None,
    as_numpy: bool = False,
    xp: ModuleType | None = None,
    **kwargs: Any,
) -> Array | tuple[Array, ...]:
    """
    Lazily apply an eager elementwise function.

    This is a variant of :func:`lazy_apply` which expects `func` to be elementwise,
    i.e. each output point must depend exclusively on the corresponding input point
    in each input array. This can result in faster execution on some backends.

    Parameters
    ----------
    func : callable
        As in `lazy_apply`, but in addition it must be elementwise.
    *args : Array | int | float | complex | bool | None
        As in `lazy_apply`.
    dtype : DType | Sequence[DType | None], optional
        Output dtype or sequence of output dtypes, one for each output of `func`.
        dtype(s) must belong to the same array namespace as the input arrays.
        This also informs how many outputs the function has.
        Default: assume a single output and infer the result type(s) from
        the input arrays.
    as_numpy : bool, optional
        As in `lazy_apply`.
    xp : array_namespace, optional
        The standard-compatible namespace for `args`. Default: infer.
    **kwargs : Any, optional
        As in `lazy_apply`.

    Returns
    -------
    Array | tuple[Array, ...]
        The result(s) of `func` applied to the input arrays, wrapped in the same
        array namespace as the inputs.
        If dtype is omitted or a single dtype, return a single array.
        Otherwise, return a tuple of arrays.

    Raises
    ------
    ValueError
        If none of `args` is an array.

    See Also
    --------
    lazy_apply : General version of this function.
    dask.array.map_blocks : Dask version of this function.

    Notes
    -----
    Unlike in :func:`lazy_apply`, you can't define output shapes that aren't
    broadcasted from the input arrays.

    Dask
        Unlike :func:`dask.array.map_blocks`, this function allows for multiple outputs.

    Dask wrapping around other backends
        If ``as_numpy=False``, `func` will receive in input eager arrays of the meta
        namespace, as defined by the ``._meta`` attribute of the input Dask arrays. The
        outputs of `func` will be wrapped by the meta namespace, and then wrapped again
        by Dask.

    All other backends
        This function is identical to :func:`lazy_apply`.
    """
    args_not_none = [arg for arg in args if arg is not None]
    array_args = [arg for arg in args_not_none if not is_python_scalar(arg)]
    if not array_args:
        msg = "Must have at least one argument array"
        raise ValueError(msg)
    if xp is None:
        xp = array_namespace(*array_args)

    # Normalize dtype: a sequence means one entry per output of `func`,
    # anything else (including None) means a single output.
    requested: list[DType | None]
    if isinstance(dtype, Sequence):
        multi_output = True
        requested = list(dtype)  # pyright: ignore[reportUnknownArgumentType]
    else:
        multi_output = False
        requested = [dtype]
    del dtype

    # Replace missing dtypes with the promoted type of the inputs.
    # Test with `is None` rather than `in` / `or`: dtype objects are not guaranteed
    # to be truthy or to compare unequal to None, so a truthiness test could
    # silently discard an explicitly requested dtype.
    # The single-output default is resolved here too, so that the Dask branch below
    # never hands dtype=None to map_blocks and ends up with the dtype of
    # ``metas[0]`` (the first input only) instead of the promoted result type.
    dtypes: list[DType]
    if any(d is None for d in requested):
        rtype = xp.result_type(*args_not_none)
        dtypes = [rtype if d is None else d for d in requested]
    else:
        dtypes = requested  # type: ignore[assignment]  # pyright: ignore[reportAssignmentType]

    if not is_dask_namespace(xp):
        # Elementwise => every output has the broadcasted shape of the inputs.
        shape = broadcast_shapes(*(arg.shape for arg in array_args))
        return lazy_apply(  # pyright: ignore[reportCallIssue]
            func,  # type: ignore[arg-type]  # pyright: ignore[reportArgumentType]
            *args,
            shape=[shape] * len(dtypes) if multi_output else shape,  # type: ignore[arg-type]  # pyright: ignore[reportArgumentType]
            dtype=dtypes if multi_output else dtypes[0],
            as_numpy=as_numpy,
            xp=xp,
            **kwargs,
        )

    # Use da.map_blocks.
    # We need to handle multiple outputs, which map_blocks can't.

    metas: list[Array] = [arg._meta for arg in array_args]  # type: ignore[attr-defined]  # pylint: disable=protected-access  # pyright: ignore[reportAttributeAccessIssue]
    meta_xp = array_namespace(*metas)

    wrapped = _lazy_apply_wrapper(func, as_numpy, multi_output, meta_xp)
    wrapped = partial(wrapped, **kwargs)

    # Hack map_blocks to handle multiple outputs. This intermediate output has bogus
    # dtype and meta, but dask.array will never know as long as we always provide
    # explicit dtype and meta.
    temp = xp.map_blocks(wrapped, *args, dtype=dtypes[0], meta=metas[0])
    out = tuple(
        temp.map_blocks(operator.itemgetter(i), dtype=out_dtype, meta=metas[0])
        for i, out_dtype in enumerate(dtypes)
    )

    return out if multi_output else out[0]
0 commit comments