|
4 | 4 | from __future__ import annotations
|
5 | 5 |
|
6 | 6 | import math
|
| 7 | +import warnings |
7 | 8 | from collections.abc import Callable, Sequence
|
8 | 9 | from functools import wraps
|
9 | 10 | from types import ModuleType
|
10 | 11 | from typing import TYPE_CHECKING, Any, cast, overload
|
11 | 12 |
|
12 |
| -from ._utils._compat import array_namespace, is_dask_namespace, is_jax_namespace |
| 13 | +from ._utils._compat import ( |
| 14 | + array_namespace, |
| 15 | + is_dask_namespace, |
| 16 | + is_jax_namespace, |
| 17 | + is_lazy_array, |
| 18 | +) |
13 | 19 | from ._utils._typing import Array, DType
|
14 | 20 |
|
15 | 21 | if TYPE_CHECKING:
|
@@ -319,3 +325,297 @@ def wrapper( # type: ignore[no-any-decorated,no-any-explicit]
|
319 | 325 | return (xp.asarray(out),)
|
320 | 326 |
|
321 | 327 | return wrapper
|
| 328 | + |
| 329 | + |
def lazy_raise(  # numpydoc ignore=SA04
    x: Array,
    cond: bool | Array,
    exc: Exception,
    *,
    xp: ModuleType | None = None,
) -> Array:
    """
    Raise if an eager check fails on a lazy array.

    Consider this snippet::

      >>> def f(x, xp):
      ...     if xp.any(x < 0):
      ...         raise ValueError("Some points are negative")
      ...     return x + 1

    The above code fails to compile when x is a JAX array and the function is wrapped
    by `jax.jit`; it is also extremely slow on Dask. Other lazy backends, e.g. ndonnx,
    are also expected to misbehave.

    `xp.any(x < 0)` is a 0-dimensional array with `dtype=bool`; the `if` statement calls
    `bool()` on the Array to convert it to a Python bool.

    On eager backends such as NumPy, this is not a problem. On Dask, `bool()` implicitly
    triggers a computation of the whole graph so far; what's worse is that the
    intermediate results are discarded to optimize memory usage, so when later on the
    user explicitly calls `compute()` on their final output, `x` is recalculated from
    scratch. On JAX, `bool()` raises if the calling code is wrapped by `jax.jit`, as
    the array has no concrete value while the function is being traced.

    You should rewrite the above code as follows::

      >>> def f(x, xp):
      ...     x = lazy_raise(x, xp.any(x < 0), ValueError("Some points are negative"))
      ...     return x + 1

    When `xp` is eager, this is equivalent to the original code; if the error condition
    resolves to True, the function raises immediately and the next line `return x + 1`
    is never executed.
    When `xp` is lazy, the function always returns a lazy array. When eventually the
    user actually computes it, e.g. in Dask by calling `compute()` and in JAX by having
    their outermost function decorated with `@jax.jit` return, only then the error
    condition is evaluated. If True, the exception is raised and propagated as normal,
    and the following nodes of the graph are never executed (so if the health check was
    in place to prevent not only incorrect results but e.g. a segmentation fault, it's
    still going to achieve its purpose).

    Parameters
    ----------
    x : Array
        Any one Array, potentially lazy, that is used later on to produce the value
        returned by your function.
    cond : bool | Array
        Must be either a plain Python bool or a 0-dimensional Array with boolean dtype.
        If True, raise the exception. If False, return x.
    exc : Exception
        The exception instance to be raised.
    xp : array_namespace, optional
        The standard-compatible namespace for `x`. Default: infer.

    Returns
    -------
    Array
        `x`. If both `x` and `cond` are lazy arrays, the graph underlying `x` is
        altered to raise `exc` if `cond` is True.

    Raises
    ------
    type(exc)
        If `cond` evaluates to True.

    Warnings
    --------
    The return value must be used. This function raises when x is eager, and
    quietly skips the check when x is lazy::

      >>> def f(x, xp):
      ...     lazy_raise(x, xp.any(x < 0), ValueError("Some points are negative"))
      ...     return x + 1

    And so does this one, as lazy_raise replaces `x` but it does so too late to
    contribute to the return value::

      >>> def f(x, xp):
      ...     y = x + 1
      ...     x = lazy_raise(x, xp.any(x < 0), ValueError("Some points are negative"))
      ...     return y

    See Also
    --------
    lazy_apply
    lazy_warn
    lazy_wait_on
    dask.graph_manipulation.wait_on

    Notes
    -----
    This function will raise if the :doc:`jax:transfer_guard` is active and `cond` is
    a JAX array on a non-CPU device.
    """

    def _lazy_raise(x: Array, cond: Array) -> Array:  # numpydoc ignore=PR01,RT01
        """Eager helper of `lazy_raise` running inside the lazy graph."""
        # By the time this runs `cond` is concrete, so `bool()` on it is cheap and safe.
        if cond:
            raise exc
        return x

    return _lazy_wait_on_impl(x, cond, _lazy_raise, xp=xp)
| 439 | + |
| 440 | + |
# Signature of warnings.warn copied from python/typeshed
@overload
def lazy_warn(  # type: ignore[no-any-explicit,no-any-decorated]  # numpydoc ignore=GL08
    x: Array,
    cond: bool | Array,
    message: str,
    category: type[Warning] | None = None,
    stacklevel: int = 1,
    source: Any | None = None,
    *,
    xp: ModuleType | None = None,
) -> Array: ...
@overload
def lazy_warn(  # type: ignore[no-any-explicit,no-any-decorated]  # numpydoc ignore=GL08
    x: Array,
    cond: bool | Array,
    message: Warning,
    category: Any = None,
    stacklevel: int = 1,
    source: Any | None = None,
    *,
    xp: ModuleType | None = None,
) -> Array: ...


def lazy_warn(  # type: ignore[no-any-explicit]  # numpydoc ignore=SA04,PR04
    x: Array,
    cond: bool | Array,
    message: str | Warning,
    category: Any = None,
    stacklevel: int = 1,
    source: Any | None = None,
    *,
    xp: ModuleType | None = None,
) -> Array:
    """
    Call `warnings.warn` if an eager check fails on a lazy array.

    This function works in the same way as `lazy_raise`; refer to it
    for the detailed explanation.

    You should replace::

      >>> def f(x, xp):
      ...     if xp.any(x < 0):
      ...         warnings.warn("Some points are negative", UserWarning, stacklevel=2)
      ...     return x + 1

    with::

      >>> def f(x, xp):
      ...     x = lazy_warn(x, xp.any(x < 0),
      ...                   "Some points are negative", UserWarning, stacklevel=2)
      ...     return x + 1

    Parameters
    ----------
    x : Array
        Any one Array, potentially lazy, that is used later on to produce the value
        returned by your function.
    cond : bool | Array
        Must be either a plain Python bool or a 0-dimensional Array with boolean dtype.
        If True, issue the warning. In either case, return x.
    message, category, stacklevel, source :
        Parameters to `warnings.warn`. `stacklevel` is automatically increased to
        compensate for the extra wrapper functions, so that on eager backends it
        counts frames starting from the caller of `lazy_warn`.
    xp : array_namespace, optional
        The standard-compatible namespace for `x`. Default: infer.

    Returns
    -------
    Array
        `x`. If both `x` and `cond` are lazy arrays, the graph underlying `x` is
        altered to issue the warning if `cond` is True.

    See Also
    --------
    warnings.warn
    lazy_apply
    lazy_raise
    lazy_wait_on
    dask.graph_manipulation.wait_on

    Notes
    -----
    On Dask, the warning is typically going to appear on the log of the
    worker executing the function instead of on the client.
    """

    def _lazy_warn(x: Array, cond: Array) -> Array:  # numpydoc ignore=PR01,RT01
        """Eager helper of `lazy_warn` running inside the lazy graph."""
        if cond:
            # On the eager path three frames separate the user's code from
            # warnings.warn: _lazy_warn, _lazy_wait_on_impl and lazy_warn.
            # Skip all of them so that stacklevel=1 points at the caller of
            # lazy_warn, exactly like a direct warnings.warn call would.
            warnings.warn(message, category, stacklevel=stacklevel + 3, source=source)
        return x

    return _lazy_wait_on_impl(x, cond, _lazy_warn, xp=xp)
| 537 | + |
| 538 | + |
def lazy_wait_on(
    x: Array, wait_on: object, *, xp: ModuleType | None = None
) -> Array:  # numpydoc ignore=SA04
    """
    Pause materialization of `x` until `wait_on` has been materialized.

    This is typically used to collect multiple calls to `lazy_raise` and/or
    `lazy_warn` from validation functions that would otherwise return None.
    If `wait_on` is not a lazy array, just return `x`.

    Read `lazy_raise` for detailed explanation.

    Parameters
    ----------
    x : Array
        Any one Array, potentially lazy, that is used later on to produce the value
        returned by your function.
    wait_on : object
        Any object. If it's a lazy array, block the materialization of `x` until
        `wait_on` has been fully materialized.
    xp : array_namespace, optional
        The standard-compatible namespace for `x`. Default: infer.

    Returns
    -------
    Array
        `x`. If both `x` and `wait_on` are lazy arrays, the graph
        underlying `x` is altered to wait until `wait_on` has been materialized.
        If `wait_on` raises, the exception is propagated to `x`.

    See Also
    --------
    lazy_apply
    lazy_raise
    lazy_warn
    dask.graph_manipulation.wait_on

    Examples
    --------
    ::

      def validate(x, xp):
          # Future that evaluates the checks. Contents are inconsequential.
          # Avoid zero-sized arrays, as they may be elided by the graph optimizer.
          future = xp.empty(1)
          future = lazy_raise(future, xp.any(x < 10), ValueError("Less than 10"))
          future = lazy_warn(future, xp.any(x > 20), "More than 20", UserWarning)
          return future

      def f(x, xp):
          x = lazy_wait_on(x, validate(x, xp), xp=xp)
          return x + 1
    """

    def _lazy_wait_on(x: Array, _: Array) -> Array:  # numpydoc ignore=PR01,RT01
        """Eager helper of `lazy_wait_on` running inside the lazy graph."""
        # The second argument only exists to create the dependency; its value
        # is deliberately ignored.
        return x

    return _lazy_wait_on_impl(x, wait_on, _lazy_wait_on, xp=xp)
| 598 | + |
| 599 | + |
def _lazy_wait_on_impl(  # numpydoc ignore=PR01,RT01
    x: Array,
    wait_on: object,
    eager_func: Callable[[Array, Array], Array],
    xp: ModuleType | None,
) -> Array:
    """Shared backend of `lazy_raise`, `lazy_warn`, and `lazy_wait_on`."""
    # Nothing to defer when the dependency is not lazy: run the helper right away.
    if not is_lazy_array(wait_on):
        return eager_func(x, wait_on)

    dependency = cast(Array, wait_on)
    if dependency.shape != ():
        msg = "cond/wait_on must be 0-dimensional"
        raise ValueError(msg)

    # Infer the namespace from both arrays unless the caller supplied one.
    namespace = array_namespace(x, wait_on) if xp is None else xp

    if not is_dask_namespace(namespace):
        return lazy_apply(
            eager_func, x, wait_on, shape=x.shape, dtype=x.dtype, xp=namespace
        )

    # lazy_apply would rechunk x
    return namespace.map_blocks(eager_func, x, wait_on, dtype=x.dtype, meta=x._meta)  # pylint: disable=protected-access
0 commit comments