-
Notifications
You must be signed in to change notification settings - Fork 8
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
ENH: apply_where
(migrate lazywhere
from scipy)
#141
Merged
Merged
Changes from all commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,6 +6,7 @@ | |
:nosignatures: | ||
:toctree: generated | ||
|
||
apply_where | ||
at | ||
atleast_nd | ||
broadcast_shapes | ||
|
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,17 +5,23 @@ | |
|
||
import math | ||
import warnings | ||
from collections.abc import Sequence | ||
from types import ModuleType | ||
from typing import cast | ||
from collections.abc import Callable, Sequence | ||
from types import ModuleType, NoneType | ||
from typing import cast, overload | ||
|
||
from ._at import at | ||
from ._utils import _compat, _helpers | ||
from ._utils._compat import array_namespace, is_jax_array | ||
from ._utils._helpers import asarrays, eager_shape, ndindex | ||
from ._utils._compat import ( | ||
array_namespace, | ||
is_dask_namespace, | ||
is_jax_array, | ||
is_jax_namespace, | ||
) | ||
from ._utils._helpers import asarrays, eager_shape, meta_namespace, ndindex | ||
from ._utils._typing import Array | ||
|
||
__all__ = [ | ||
"apply_where", | ||
"atleast_nd", | ||
"broadcast_shapes", | ||
"cov", | ||
|
@@ -29,6 +35,148 @@ | |
] | ||
|
||
|
||
@overload | ||
def apply_where( # type: ignore[no-any-explicit,no-any-decorated] # numpydoc ignore=GL08 | ||
cond: Array, | ||
args: Array | tuple[Array, ...], | ||
f1: Callable[..., Array], | ||
f2: Callable[..., Array], | ||
/, | ||
*, | ||
xp: ModuleType | None = None, | ||
) -> Array: ... | ||
|
||
|
||
@overload | ||
def apply_where( # type: ignore[no-any-explicit,no-any-decorated] # numpydoc ignore=GL08 | ||
cond: Array, | ||
args: Array | tuple[Array, ...], | ||
f1: Callable[..., Array], | ||
/, | ||
*, | ||
fill_value: Array | complex, | ||
xp: ModuleType | None = None, | ||
) -> Array: ... | ||
|
||
|
||
def apply_where( # type: ignore[no-any-explicit] # numpydoc ignore=PR01,PR02 | ||
cond: Array, | ||
args: Array | tuple[Array, ...], | ||
f1: Callable[..., Array], | ||
f2: Callable[..., Array] | None = None, | ||
/, | ||
*, | ||
fill_value: Array | complex | None = None, | ||
xp: ModuleType | None = None, | ||
) -> Array: | ||
""" | ||
Run one of two elementwise functions depending on a condition. | ||
|
||
Equivalent to ``f1(*args) if cond else fill_value`` performed elementwise | ||
when `fill_value` is defined, otherwise to ``f1(*args) if cond else f2(*args)``. | ||
|
||
Parameters | ||
---------- | ||
cond : array | ||
The condition, expressed as a boolean array. | ||
args : Array or tuple of Arrays | ||
Argument(s) to `f1` (and `f2`). Must be broadcastable with `cond`. | ||
f1 : callable | ||
Elementwise function of `args`, returning a single array. | ||
Where `cond` is True, output will be ``f1(arg0[cond], arg1[cond], ...)``. | ||
f2 : callable, optional | ||
Elementwise function of `args`, returning a single array. | ||
Where `cond` is False, output will be ``f2(arg0[cond], arg1[cond], ...)``. | ||
Mutually exclusive with `fill_value`. | ||
fill_value : Array or scalar, optional | ||
If provided, value with which to fill output array where `cond` is False. | ||
It does not need to be scalar; it needs however to be broadcastable with | ||
`cond` and `args`. | ||
Mutually exclusive with `f2`. You must provide one or the other. | ||
xp : array_namespace, optional | ||
The standard-compatible namespace for `cond` and `args`. Default: infer. | ||
|
||
Returns | ||
------- | ||
Array | ||
An array with elements from the output of `f1` where `cond` is True and either | ||
the output of `f2` or `fill_value` where `cond` is False. The returned array has | ||
data type determined by type promotion rules between the output of `f1` and | ||
either `fill_value` or the output of `f2`. | ||
|
||
Notes | ||
----- | ||
``xp.where(cond, f1(*args), f2(*args))`` requires explicitly evaluating `f1` even | ||
when `cond` is False, and `f2` when cond is True. This function evaluates each | ||
function only for their matching condition, if the backend allows for it. | ||
|
||
On Dask, `f1` and `f2` are applied to the individual chunks and should use functions | ||
from the namespace of the chunks. | ||
|
||
Examples | ||
-------- | ||
>>> import array_api_strict as xp | ||
>>> import array_api_extra as xpx | ||
>>> a = xp.asarray([5, 4, 3]) | ||
>>> b = xp.asarray([0, 2, 2]) | ||
>>> def f(a, b): | ||
... return a // b | ||
>>> xpx.apply_where(b != 0, (a, b), f, fill_value=xp.nan) | ||
array([ nan, 2., 1.]) | ||
""" | ||
# Parse and normalize arguments | ||
if (f2 is None) == (fill_value is None): | ||
msg = "Exactly one of `fill_value` or `f2` must be given." | ||
raise TypeError(msg) | ||
args_ = list(args) if isinstance(args, tuple) else [args] | ||
del args | ||
|
||
xp = array_namespace(cond, fill_value, *args_) if xp is None else xp | ||
|
||
if isinstance(fill_value, int | float | complex | NoneType): | ||
cond, *args_ = xp.broadcast_arrays(cond, *args_) | ||
else: | ||
cond, fill_value, *args_ = xp.broadcast_arrays(cond, fill_value, *args_) | ||
|
||
if is_dask_namespace(xp): | ||
meta_xp = meta_namespace(cond, fill_value, *args_, xp=xp) | ||
# map_blocks doesn't descend into tuples of Arrays | ||
return xp.map_blocks(_apply_where, cond, f1, f2, fill_value, *args_, xp=meta_xp) | ||
return _apply_where(cond, f1, f2, fill_value, *args_, xp=xp) | ||
|
||
|
||
def _apply_where( # type: ignore[no-any-explicit] # numpydoc ignore=PR01,RT01 | ||
cond: Array, | ||
f1: Callable[..., Array], | ||
f2: Callable[..., Array] | None, | ||
fill_value: Array | int | float | complex | bool | None, | ||
*args: Array, | ||
xp: ModuleType, | ||
) -> Array: | ||
"""Helper of `apply_where`. On Dask, this runs on a single chunk.""" | ||
|
||
if is_jax_namespace(xp): | ||
# jax.jit does not support assignment by boolean mask | ||
return xp.where(cond, f1(*args), f2(*args) if f2 is not None else fill_value) | ||
|
||
temp1 = f1(*(arr[cond] for arr in args)) | ||
|
||
if f2 is None: | ||
dtype = xp.result_type(temp1, fill_value) | ||
if isinstance(fill_value, int | float | complex): | ||
out = xp.full_like(cond, dtype=dtype, fill_value=fill_value) | ||
else: | ||
out = xp.astype(fill_value, dtype, copy=True) | ||
else: | ||
ncond = ~cond | ||
temp2 = f2(*(arr[ncond] for arr in args)) | ||
dtype = xp.result_type(temp1, temp2) | ||
out = xp.empty_like(cond, dtype=dtype) | ||
out = at(out, ncond).set(temp2) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. JAX doesn't benefit from this |
||
|
||
return at(out, cond).set(temp1) | ||
|
||
|
||
def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType | None = None) -> Array: | ||
""" | ||
Recursively expand the dimension of an array to at least `ndim`. | ||
|
@@ -393,12 +541,15 @@ def isclose( | |
a_inexact = xp.isdtype(a.dtype, ("real floating", "complex floating")) | ||
b_inexact = xp.isdtype(b.dtype, ("real floating", "complex floating")) | ||
if a_inexact or b_inexact: | ||
# FIXME: use scipy's lazywhere to suppress warnings on inf | ||
out = xp.where( | ||
# prevent warnings on numpy and dask on inf - inf | ||
mxp = meta_namespace(a, b, xp=xp) | ||
out = apply_where( | ||
xp.isinf(a) | xp.isinf(b), | ||
xp.isinf(a) & xp.isinf(b) & (xp.sign(a) == xp.sign(b)), | ||
(a, b), | ||
lambda a, b: mxp.isinf(a) & mxp.isinf(b) & (mxp.sign(a) == mxp.sign(b)), # pyright: ignore[reportUnknownArgumentType] | ||
# Note: inf <= inf is True! | ||
xp.abs(a - b) <= (atol + rtol * xp.abs(b)), | ||
lambda a, b: mxp.abs(a - b) <= (atol + rtol * mxp.abs(b)), # pyright: ignore[reportUnknownArgumentType] | ||
xp=xp, | ||
) | ||
if equal_nan: | ||
out = xp.where(xp.isnan(a) & xp.isnan(b), xp.asarray(True), out) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
JAX-on-dask is currently unsupported by Dask. This is here and not much higher above only for this reason.
We'll need to add explicit unit tests when it lands in the future.