Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: apply_where (migrate lazywhere from scipy) #141

Merged
merged 5 commits into from
Mar 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api-reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
:nosignatures:
:toctree: generated

apply_where
at
atleast_nd
broadcast_shapes
Expand Down
107 changes: 106 additions & 1 deletion pixi.lock

Large diffs are not rendered by default.

6 changes: 6 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ numpydoc = ">=1.8.0,<2"
array-api-strict = "*"
numpy = "*"
pytest = "*"
hypothesis = "*"
dask-core = "*" # No distributed, tornado, etc.
# NOTE: don't add jax, pytorch, sparse, cupy here
# as they slow down mypy and are not portable across target OSs
Expand All @@ -79,6 +80,7 @@ lint = { depends-on = ["pre-commit", "pylint", "mypy", "pyright"] }
[tool.pixi.feature.tests.dependencies]
pytest = ">=6"
pytest-cov = ">=3"
hypothesis = "*"
array-api-strict = "*"
numpy = "*"

Expand Down Expand Up @@ -231,6 +233,10 @@ reportMissingTypeStubs = false
reportUnreachable = false
# ruff handles this
reportUnusedParameter = false
# cyclic imports inside function bodies
reportImportCycles = false
# PyRight can't trace types in lambdas
reportUnknownLambdaType = false

executionEnvironments = [
{ root = "tests", reportPrivateUsage = false },
Expand Down
2 changes: 2 additions & 0 deletions src/array_api_extra/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from ._delegation import isclose, pad
from ._lib._at import at
from ._lib._funcs import (
apply_where,
atleast_nd,
broadcast_shapes,
cov,
Expand All @@ -19,6 +20,7 @@
# pylint: disable=duplicate-code
__all__ = [
"__version__",
"apply_where",
"at",
"atleast_nd",
"broadcast_shapes",
Expand Down
22 changes: 18 additions & 4 deletions src/array_api_extra/_lib/_at.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
is_jax_array,
is_writeable_array,
)
from ._utils._helpers import meta_namespace
from ._utils._typing import Array, SetIndex


Expand Down Expand Up @@ -263,6 +264,8 @@ def _op(
Array
Updated `x`.
"""
from ._funcs import apply_where # pylint: disable=cyclic-import

x, idx = self._x, self._idx
xp = array_namespace(x, y) if xp is None else xp

Expand Down Expand Up @@ -295,8 +298,10 @@ def _op(
y_xp = xp.asarray(y, dtype=x.dtype)
if y_xp.ndim == 0:
if out_of_place_op: # add(), subtract(), ...
# FIXME: suppress inf warnings on dask with lazywhere
out = xp.where(idx, out_of_place_op(x, y_xp), x)
# suppress inf warnings on Dask
out = apply_where(
idx, (x, y_xp), out_of_place_op, fill_value=x, xp=xp
)
# Undo int->float promotion on JAX after _AtOp.DIVIDE
out = xp.astype(out, x.dtype, copy=False)
else: # set()
Expand Down Expand Up @@ -420,9 +425,16 @@ def min(
xp: ModuleType | None = None,
) -> Array: # numpydoc ignore=PR01,RT01
"""Apply ``x[idx] = minimum(x[idx], y)`` and return the updated array."""
# On Dask, this function runs on the chunks, so we need to determine the
# namespace that Dask is wrapping.
# Note that da.minimum _incidentally_ works on numpy, cupy, and sparse
# thanks to all these meta-namespaces implementing the __array_ufunc__
# interface, but there's no guarantee that it will work for other
# wrapped libraries in the future.
xp = array_namespace(self._x) if xp is None else xp
mxp = meta_namespace(self._x, xp=xp)
y = xp.asarray(y)
return self._op(_AtOp.MIN, xp.minimum, xp.minimum, y, copy=copy, xp=xp)
return self._op(_AtOp.MIN, mxp.minimum, mxp.minimum, y, copy=copy, xp=xp)

def max(
self,
Expand All @@ -432,6 +444,8 @@ def max(
xp: ModuleType | None = None,
) -> Array: # numpydoc ignore=PR01,RT01
"""Apply ``x[idx] = maximum(x[idx], y)`` and return the updated array."""
# See note on min()
xp = array_namespace(self._x) if xp is None else xp
mxp = meta_namespace(self._x, xp=xp)
y = xp.asarray(y)
return self._op(_AtOp.MAX, xp.maximum, xp.maximum, y, copy=copy, xp=xp)
return self._op(_AtOp.MAX, mxp.maximum, mxp.maximum, y, copy=copy, xp=xp)
169 changes: 160 additions & 9 deletions src/array_api_extra/_lib/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,23 @@

import math
import warnings
from collections.abc import Sequence
from types import ModuleType
from typing import cast
from collections.abc import Callable, Sequence
from types import ModuleType, NoneType
from typing import cast, overload

from ._at import at
from ._utils import _compat, _helpers
from ._utils._compat import array_namespace, is_jax_array
from ._utils._helpers import asarrays, eager_shape, ndindex
from ._utils._compat import (
array_namespace,
is_dask_namespace,
is_jax_array,
is_jax_namespace,
)
from ._utils._helpers import asarrays, eager_shape, meta_namespace, ndindex
from ._utils._typing import Array

__all__ = [
"apply_where",
"atleast_nd",
"broadcast_shapes",
"cov",
Expand All @@ -29,6 +35,148 @@
]


@overload
def apply_where( # type: ignore[no-any-explicit,no-any-decorated] # numpydoc ignore=GL08
cond: Array,
args: Array | tuple[Array, ...],
f1: Callable[..., Array],
f2: Callable[..., Array],
/,
*,
xp: ModuleType | None = None,
) -> Array: ...


@overload
def apply_where( # type: ignore[no-any-explicit,no-any-decorated] # numpydoc ignore=GL08
cond: Array,
args: Array | tuple[Array, ...],
f1: Callable[..., Array],
/,
*,
fill_value: Array | complex,
xp: ModuleType | None = None,
) -> Array: ...


def apply_where( # type: ignore[no-any-explicit] # numpydoc ignore=PR01,PR02
cond: Array,
args: Array | tuple[Array, ...],
f1: Callable[..., Array],
f2: Callable[..., Array] | None = None,
/,
*,
fill_value: Array | complex | None = None,
xp: ModuleType | None = None,
) -> Array:
"""
Run one of two elementwise functions depending on a condition.

Equivalent to ``f1(*args) if cond else fill_value`` performed elementwise
when `fill_value` is defined, otherwise to ``f1(*args) if cond else f2(*args)``.

Parameters
----------
cond : array
The condition, expressed as a boolean array.
args : Array or tuple of Arrays
Argument(s) to `f1` (and `f2`). Must be broadcastable with `cond`.
f1 : callable
Elementwise function of `args`, returning a single array.
Where `cond` is True, output will be ``f1(arg0[cond], arg1[cond], ...)``.
f2 : callable, optional
Elementwise function of `args`, returning a single array.
Where `cond` is False, output will be ``f2(arg0[cond], arg1[cond], ...)``.
Mutually exclusive with `fill_value`.
fill_value : Array or scalar, optional
If provided, value with which to fill output array where `cond` is False.
It does not need to be scalar; it needs however to be broadcastable with
`cond` and `args`.
Mutually exclusive with `f2`. You must provide one or the other.
xp : array_namespace, optional
The standard-compatible namespace for `cond` and `args`. Default: infer.

Returns
-------
Array
An array with elements from the output of `f1` where `cond` is True and either
the output of `f2` or `fill_value` where `cond` is False. The returned array has
data type determined by type promotion rules between the output of `f1` and
either `fill_value` or the output of `f2`.

Notes
-----
``xp.where(cond, f1(*args), f2(*args))`` requires explicitly evaluating `f1` even
when `cond` is False, and `f2` when cond is True. This function evaluates each
function only for their matching condition, if the backend allows for it.

On Dask, `f1` and `f2` are applied to the individual chunks and should use functions
from the namespace of the chunks.

Examples
--------
>>> import array_api_strict as xp
>>> import array_api_extra as xpx
>>> a = xp.asarray([5, 4, 3])
>>> b = xp.asarray([0, 2, 2])
>>> def f(a, b):
... return a // b
>>> xpx.apply_where(b != 0, (a, b), f, fill_value=xp.nan)
array([ nan, 2., 1.])
"""
# Parse and normalize arguments
if (f2 is None) == (fill_value is None):
msg = "Exactly one of `fill_value` or `f2` must be given."
raise TypeError(msg)
args_ = list(args) if isinstance(args, tuple) else [args]
del args

xp = array_namespace(cond, fill_value, *args_) if xp is None else xp

if isinstance(fill_value, int | float | complex | NoneType):
cond, *args_ = xp.broadcast_arrays(cond, *args_)
else:
cond, fill_value, *args_ = xp.broadcast_arrays(cond, fill_value, *args_)

if is_dask_namespace(xp):
meta_xp = meta_namespace(cond, fill_value, *args_, xp=xp)
# map_blocks doesn't descend into tuples of Arrays
return xp.map_blocks(_apply_where, cond, f1, f2, fill_value, *args_, xp=meta_xp)
return _apply_where(cond, f1, f2, fill_value, *args_, xp=xp)


def _apply_where(  # type: ignore[no-any-explicit] # numpydoc ignore=PR01,RT01
    cond: Array,
    f1: Callable[..., Array],
    f2: Callable[..., Array] | None,
    fill_value: Array | int | float | complex | bool | None,
    *args: Array,
    xp: ModuleType,
) -> Array:
    """
    Implement `apply_where` on already-broadcast inputs.

    On Dask, this is the function mapped over the blocks, so it sees one chunk
    at a time and `xp` is the namespace of the chunks.
    """
    if is_jax_namespace(xp):
        # jax.jit does not support assignment by boolean mask, so both branches
        # are evaluated on the full arrays and merged with where().
        # NOTE (from the PR review): JAX-on-Dask is currently unsupported by Dask;
        # that is the only reason this check lives here rather than much higher
        # up. Explicit unit tests will be needed when that support lands.
        on_true = f1(*args)
        on_false = fill_value if f2 is None else f2(*args)
        return xp.where(cond, on_true, on_false)

    # Evaluate f1 only on the elements selected by the condition.
    on_true = f1(*[arr[cond] for arr in args])

    if f2 is not None:
        # Evaluate f2 only on the complementary elements, then scatter them
        # into an uninitialised buffer of the promoted dtype.
        inverse = ~cond
        on_false = f2(*[arr[inverse] for arr in args])
        out_dtype = xp.result_type(on_true, on_false)
        buffer = xp.empty_like(cond, dtype=out_dtype)
        # NOTE (from the PR review): JAX doesn't benefit from this `at`,
        # but Sparse will (eventually).
        out = at(buffer, inverse).set(on_false)
    else:
        out_dtype = xp.result_type(on_true, fill_value)
        if isinstance(fill_value, (int, float, complex)):
            # Python scalar: materialise it with the shape of the condition.
            out = xp.full_like(cond, dtype=out_dtype, fill_value=fill_value)
        else:
            # Array fill value: copy it so the caller's array is not modified.
            out = xp.astype(fill_value, out_dtype, copy=True)

    # Finally overwrite the positions where the condition holds.
    return at(out, cond).set(on_true)


def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType | None = None) -> Array:
"""
Recursively expand the dimension of an array to at least `ndim`.
Expand Down Expand Up @@ -393,12 +541,15 @@ def isclose(
a_inexact = xp.isdtype(a.dtype, ("real floating", "complex floating"))
b_inexact = xp.isdtype(b.dtype, ("real floating", "complex floating"))
if a_inexact or b_inexact:
# FIXME: use scipy's lazywhere to suppress warnings on inf
out = xp.where(
# prevent warnings on numpy and dask on inf - inf
mxp = meta_namespace(a, b, xp=xp)
out = apply_where(
xp.isinf(a) | xp.isinf(b),
xp.isinf(a) & xp.isinf(b) & (xp.sign(a) == xp.sign(b)),
(a, b),
lambda a, b: mxp.isinf(a) & mxp.isinf(b) & (mxp.sign(a) == mxp.sign(b)), # pyright: ignore[reportUnknownArgumentType]
# Note: inf <= inf is True!
xp.abs(a - b) <= (atol + rtol * xp.abs(b)),
lambda a, b: mxp.abs(a - b) <= (atol + rtol * mxp.abs(b)), # pyright: ignore[reportUnknownArgumentType]
xp=xp,
)
if equal_nan:
out = xp.where(xp.isnan(a) & xp.isnan(b), xp.asarray(True), out)
Expand Down
46 changes: 44 additions & 2 deletions src/array_api_extra/_lib/_utils/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,27 @@
from typing import TYPE_CHECKING, cast

from . import _compat
from ._compat import array_namespace, is_array_api_obj, is_numpy_array
from ._compat import (
array_namespace,
is_array_api_obj,
is_dask_namespace,
is_numpy_array,
)
from ._typing import Array

if TYPE_CHECKING: # pragma: no cover
# TODO import from typing (requires Python >=3.13)
from typing_extensions import TypeIs


__all__ = ["asarrays", "eager_shape", "in1d", "is_python_scalar", "mean"]
__all__ = [
"asarrays",
"eager_shape",
"in1d",
"is_python_scalar",
"mean",
"meta_namespace",
]


def in1d(
Expand Down Expand Up @@ -230,3 +242,33 @@ def eager_shape(x: Array, /) -> tuple[int, ...]:
msg = "Unsupported lazy shape"
raise TypeError(msg)
return cast(tuple[int, ...], shape)


def meta_namespace(
    *arrays: Array | int | float | complex | bool | None,
    xp: ModuleType | None = None,
) -> ModuleType:
    """
    Return the namespace of the chunks wrapped by Dask arrays.

    For every backend other than Dask, the namespace of the arrays themselves
    is returned unchanged.

    Parameters
    ----------
    *arrays : Array | int | float | complex | bool | None
        Input arrays. Python scalars and None are accepted as well.
    xp : array_namespace, optional
        The standard-compatible namespace for the input arrays. Default: infer.

    Returns
    -------
    array_namespace
        The namespace of the Dask chunks when `xp` is Dask;
        the namespace of the arrays in every other case.
    """
    if xp is None:
        xp = array_namespace(*arrays)

    if is_dask_namespace(xp):
        # Inputs without a `_meta` attribute (scalars, None) contribute None,
        # so they are quietly skipped.
        chunk_metas = [
            cast(Array | None, getattr(item, "_meta", None)) for item in arrays
        ]
        return array_namespace(*chunk_metas)

    return xp
13 changes: 1 addition & 12 deletions tests/test_at.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,18 +277,7 @@ def test_bool_mask_nd(xp: ModuleType):
xp_assert_equal(z, xp.asarray([[0, 2, 3], [4, 0, 0]]))


@pytest.mark.parametrize(
"bool_mask",
[
False,
pytest.param(
True,
marks=pytest.mark.xfail_xp_backend(
Backend.DASK, reason="FIXME need scipy's lazywhere"
),
),
],
)
@pytest.mark.parametrize("bool_mask", [False, True])
def test_no_inf_warnings(xp: ModuleType, bool_mask: bool):
x = xp.asarray([math.inf, 1.0, 2.0])
idx = ~xp.isinf(x) if bool_mask else slice(1, None)
Expand Down
Loading
Loading