Skip to content

Commit fa81046

Browse files
committed
DNM ENH: More lazy functions
1 parent 70ddac8 commit fa81046

File tree

8 files changed

+365
-47
lines changed

8 files changed

+365
-47
lines changed

docs/api-reference.md

+3
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@
1313
expand_dims
1414
kron
1515
lazy_apply
16+
lazy_raise
17+
lazy_wait_on
18+
lazy_warn
1619
nunique
1720
pad
1821
setdiff1d

pixi.lock

+43-43
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

+8-2
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,9 @@ classifiers = [
2626
"Typing :: Typed",
2727
]
2828
dynamic = ["version"]
29-
dependencies = ["array-api-compat>=1.10.0,<2"]
29+
# DNM
30+
dependencies = ["array-api-compat @ git+https://github.com/data-apis/array-api-compat.git@8a7999434452019c3110e06f6224fa71a023a549"]
31+
# dependencies = ["array-api-compat>=1.10.0,<2"]
3032

3133
[project.urls]
3234
Homepage = "https://github.com/data-apis/array-api-extra"
@@ -39,6 +41,9 @@ Changelog = "https://github.com/data-apis/array-api-extra/releases"
3941
[tool.hatch]
4042
version.path = "src/array_api_extra/__init__.py"
4143

44+
# DNM
45+
[tool.hatch.metadata]
46+
allow-direct-references = true # Enable git dependencies
4247

4348
# Pixi
4449

@@ -48,7 +53,8 @@ platforms = ["linux-64", "osx-arm64", "win-64"]
4853

4954
[tool.pixi.dependencies]
5055
python = ">=3.10,<3.14"
51-
array-api-compat = ">=1.10.0,<2"
56+
# DNM
57+
# array-api-compat = ">=1.10.0,<2"
5258

5359
[tool.pixi.pypi-dependencies]
5460
array-api-extra = { path = ".", editable = true }

src/array_api_extra/__init__.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
setdiff1d,
1313
sinc,
1414
)
15-
from ._lib._lazy import lazy_apply
15+
from ._lib._lazy import lazy_apply, lazy_raise, lazy_wait_on, lazy_warn
1616

1717
__version__ = "0.6.1.dev0"
1818

@@ -26,6 +26,9 @@
2626
"expand_dims",
2727
"kron",
2828
"lazy_apply",
29+
"lazy_raise",
30+
"lazy_wait_on",
31+
"lazy_warn",
2932
"nunique",
3033
"pad",
3134
"setdiff1d",

src/array_api_extra/_lib/_lazy.py

+301-1
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,18 @@
44
from __future__ import annotations
55

66
import math
7+
import warnings
78
from collections.abc import Callable, Sequence
89
from functools import wraps
910
from types import ModuleType
1011
from typing import TYPE_CHECKING, Any, cast, overload
1112

12-
from ._utils._compat import array_namespace, is_dask_namespace, is_jax_namespace
13+
from ._utils._compat import (
14+
array_namespace,
15+
is_dask_namespace,
16+
is_jax_namespace,
17+
is_lazy_array,
18+
)
1319
from ._utils._typing import Array, DType
1420

1521
if TYPE_CHECKING:
@@ -319,3 +325,297 @@ def wrapper( # type: ignore[no-any-decorated,no-any-explicit]
319325
return (xp.asarray(out),)
320326

321327
return wrapper
328+
329+
330+
def lazy_raise( # numpydoc ignore=SA04
331+
x: Array,
332+
cond: bool | Array,
333+
exc: Exception,
334+
*,
335+
xp: ModuleType | None = None,
336+
) -> Array:
337+
"""
338+
Raise if an eager check fails on a lazy array.
339+
340+
Consider this snippet::
341+
342+
>>> def f(x, xp):
343+
... if xp.any(x < 0):
344+
... raise ValueError("Some points are negative")
345+
... return x + 1
346+
347+
The above code fails to compile when x is a JAX array and the function is wrapped
348+
by `jax.jit`; it is also extremely slow on Dask. Other lazy backends, e.g. ndonnx,
349+
are also expected to misbehave.
350+
351+
`xp.any(x < 0)` is a 0-dimensional array with `dtype=bool`; the `if` statement calls
352+
`bool()` on the Array to convert it to a Python bool.
353+
354+
On eager backends such as NumPy, this is not a problem. On Dask, `bool()` implicitly
355+
triggers a computation of the whole graph so far; what's worse is that the
356+
intermediate results are discarded to optimize memory usage, so when later on the user
357+
explicitly calls `compute()` on their final output, `x` is recalculated from
358+
scratch. On JAX, `bool()` raises if the code calling it is wrapped by `jax.jit` for the
359+
same reason.
360+
361+
You should rewrite the above code as follows::
362+
363+
>>> def f(x, xp):
364+
... x = lazy_raise(x, xp.any(x < 0), ValueError("Some points are negative"))
365+
... return x + 1
366+
367+
When `xp` is eager, this is equivalent to the original code; if the error condition
368+
resolves to True, the function raises immediately and the next line `return x + 1`
369+
is never executed.
370+
When `xp` is lazy, the function always returns a lazy array. When eventually the
371+
user actually computes it, e.g. in Dask by calling `compute()` and in JAX by having
372+
their outermost function decorated with `@jax.jit` return, only then the error
373+
condition is evaluated. If True, the exception is raised and propagated as normal,
374+
and the following nodes of the graph are never executed (so if the health check was
375+
in place to prevent not only incorrect results but e.g. a segmentation fault, it's
376+
still going to achieve its purpose).
377+
378+
Parameters
379+
----------
380+
x : Array
381+
Any one Array, potentially lazy, that is used later on to produce the value
382+
returned by your function.
383+
cond : bool | Array
384+
Must be either a plain Python bool or a 0-dimensional Array with boolean dtype.
385+
If True, raise the exception. If False, return x.
386+
exc : Exception
387+
The exception instance to be raised.
388+
xp : array_namespace, optional
389+
The standard-compatible namespace for `x`. Default: infer.
390+
391+
Returns
392+
-------
393+
Array
394+
`x`. If both `x` and `cond` are lazy arrays, the graph underlying `x` is altered
395+
to raise `exc` if `cond` is True.
396+
397+
Raises
398+
------
399+
type(exc)
400+
If `cond` evaluates to True.
401+
402+
Warnings
403+
--------
404+
The following function raises when x is eager, and quietly skips the check
405+
when x is lazy::
406+
407+
>>> def f(x, xp):
408+
... lazy_raise(x, xp.any(x < 0), ValueError("Some points are negative"))
409+
... return x + 1
410+
411+
And so does this one, as lazy_raise replaces `x` but it does so too late to
412+
contribute to the return value::
413+
414+
>>> def f(x, xp):
415+
... y = x + 1
416+
... x = lazy_raise(x, xp.any(x < 0), ValueError("Some points are negative"))
417+
... return y
418+
419+
See Also
420+
--------
421+
lazy_apply
422+
lazy_warn
423+
lazy_wait_on
424+
dask.graph_manipulation.wait_on
425+
426+
Notes
427+
-----
428+
This function will raise if the :doc:`jax:transfer_guard` is active and `cond` is
429+
a JAX array on a non-CPU device.
430+
"""
431+
432+
def _lazy_raise(x: Array, cond: Array) -> Array: # numpydoc ignore=PR01,RT01
433+
"""Eager helper of `lazy_raise` running inside the lazy graph."""
434+
if cond:
435+
raise exc
436+
return x
437+
438+
return _lazy_wait_on_impl(x, cond, _lazy_raise, xp=xp)
439+
440+
441+
# Signature of warnings.warn copied from python/typeshed
442+
@overload
443+
def lazy_warn( # type: ignore[no-any-explicit,no-any-decorated] # numpydoc ignore=GL08
444+
x: Array,
445+
cond: bool | Array,
446+
message: str,
447+
category: type[Warning] | None = None,
448+
stacklevel: int = 1,
449+
source: Any | None = None,
450+
*,
451+
xp: ModuleType | None = None,
452+
) -> Array: ...
453+
@overload
454+
def lazy_warn( # type: ignore[no-any-explicit,no-any-decorated] # numpydoc ignore=GL08
455+
x: Array,
456+
cond: bool | Array,
457+
message: Warning,
458+
category: Any = None,
459+
stacklevel: int = 1,
460+
source: Any | None = None,
461+
*,
462+
xp: ModuleType | None = None,
463+
) -> Array: ...
464+
465+
466+
def lazy_warn( # type: ignore[no-any-explicit] # numpydoc ignore=SA04,PR04
467+
x: Array,
468+
cond: bool | Array,
469+
message: str | Warning,
470+
category: Any = None,
471+
stacklevel: int = 1,
472+
source: Any | None = None,
473+
*,
474+
xp: ModuleType | None = None,
475+
) -> Array:
476+
"""
477+
Call `warnings.warn` if an eager check fails on a lazy array.
478+
479+
This function works in the same way as `lazy_raise`; refer to it
480+
for the detailed explanation.
481+
482+
You should replace::
483+
484+
>>> def f(x, xp):
485+
... if xp.any(x < 0):
486+
... warnings.warn("Some points are negative", UserWarning, stacklevel=2)
487+
... return x + 1
488+
489+
with::
490+
491+
>>> def f(x, xp):
492+
... x = lazy_warn(x, xp.any(x < 0),
493+
... "Some points are negative", UserWarning, stacklevel=2)
494+
... return x + 1
495+
496+
Parameters
497+
----------
498+
x : Array
499+
Any one Array, potentially lazy, that is used later on to produce the value
500+
returned by your function.
501+
cond : bool | Array
502+
Must be either a plain Python bool or a 0-dimensional Array with boolean dtype.
503+
If True, issue the warning. In either case, return x.
504+
message, category, stacklevel, source :
505+
Parameters to `warnings.warn`. `stacklevel` is automatically increased to
506+
compensate for the extra wrapper function.
507+
xp : array_namespace, optional
508+
The standard-compatible namespace for `x`. Default: infer.
509+
510+
Returns
511+
-------
512+
Array
513+
`x`. If both `x` and `cond` are lazy arrays, the graph underlying `x` is altered
514+
to issue the warning if `cond` is True.
515+
516+
See Also
517+
--------
518+
warnings.warn
519+
lazy_apply
520+
lazy_raise
521+
lazy_wait_on
522+
dask.graph_manipulation.wait_on
523+
524+
Notes
525+
-----
526+
On Dask, the warning is typically going to appear on the log of the
527+
worker executing the function instead of on the client.
528+
"""
529+
530+
def _lazy_warn(x: Array, cond: Array) -> Array: # numpydoc ignore=PR01,RT01
531+
"""Eager helper of `lazy_warn` running inside the lazy graph."""
532+
if cond:
533+
warnings.warn(message, category, stacklevel=stacklevel + 2, source=source)
534+
return x
535+
536+
return _lazy_wait_on_impl(x, cond, _lazy_warn, xp=xp)
537+
538+
539+
def lazy_wait_on(
540+
x: Array, wait_on: object, *, xp: ModuleType | None = None
541+
) -> Array: # numpydoc ignore=SA04
542+
"""
543+
Pause materialization of `x` until `wait_on` has been materialized.
544+
545+
This is typically used to collect multiple calls to `lazy_raise` and/or
546+
`lazy_warn` from validation functions that would otherwise return None.
547+
If `wait_on` is not a lazy array, just return `x`.
548+
549+
Read `lazy_raise` for detailed explanation.
550+
551+
Parameters
552+
----------
553+
x : Array
554+
Any one Array, potentially lazy, that is used later on to produce the value
555+
returned by your function.
556+
wait_on : object
557+
Any object. If it's a lazy array (which must be 0-dimensional), block the materialization of `x` until
558+
`wait_on` has been fully materialized.
559+
xp : array_namespace, optional
560+
The standard-compatible namespace for `x`. Default: infer.
561+
562+
Returns
563+
-------
564+
Array
565+
`x`. If both `x` and `wait_on` are lazy arrays, the graph
566+
underlying `x` is altered to wait until `wait_on` has been materialized.
567+
If `wait_on` raises, the exception is propagated to `x`.
568+
569+
See Also
570+
--------
571+
lazy_apply
572+
lazy_raise
573+
lazy_warn
574+
dask.graph_manipulation.wait_on
575+
576+
Examples
577+
--------
578+
::
579+
580+
def validate(x, xp):
581+
# Future that evaluates the checks. Contents are inconsequential.
582+
# Avoid zero-sized arrays, as they may be elided by the graph optimizer.
583+
future = xp.empty(())
584+
future = lazy_raise(future, xp.any(x < 10), ValueError("Less than 10"))
585+
future = lazy_warn(future, xp.any(x > 20), "More than 20", UserWarning)
586+
return future
587+
588+
def f(x, xp):
589+
x = lazy_wait_on(x, validate(x, xp), xp=xp)
590+
return x + 1
591+
"""
592+
593+
def _lazy_wait_on(x: Array, _: Array) -> Array: # numpydoc ignore=PR01,RT01
594+
"""Eager helper of `lazy_wait_on` running inside the lazy graph."""
595+
return x
596+
597+
return _lazy_wait_on_impl(x, wait_on, _lazy_wait_on, xp=xp)
598+
599+
600+
def _lazy_wait_on_impl( # numpydoc ignore=PR01,RT01
601+
x: Array,
602+
wait_on: object,
603+
eager_func: Callable[[Array, Array], Array],
604+
xp: ModuleType | None,
605+
) -> Array:
606+
"""Implementation of lazy_raise, lazy_warn, and lazy_wait_on."""
607+
if not is_lazy_array(wait_on):
608+
return eager_func(x, wait_on)
609+
610+
if cast(Array, wait_on).shape != ():
611+
msg = "cond/wait_on must be 0-dimensional"
612+
raise ValueError(msg)
613+
614+
if xp is None:
615+
xp = array_namespace(x, wait_on)
616+
617+
if is_dask_namespace(xp):
618+
# lazy_apply would rechunk x
619+
return xp.map_blocks(eager_func, x, wait_on, dtype=x.dtype, meta=x._meta) # pylint: disable=protected-access
620+
621+
return lazy_apply(eager_func, x, wait_on, shape=x.shape, dtype=x.dtype, xp=xp)

0 commit comments

Comments
 (0)