Skip to content

Commit 0523d9f

Browse files
committed
ENH: apply_where (migrate _lazywhere from scipy)
1 parent 308fc1f commit 0523d9f

File tree

10 files changed

+588
-32
lines changed

10 files changed

+588
-32
lines changed

docs/api-reference.md

+1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
:nosignatures:
77
:toctree: generated
88
9+
apply_where
910
at
1011
atleast_nd
1112
broadcast_shapes

pixi.lock

+106-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

+6
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ numpydoc = ">=1.8.0,<2"
6464
array-api-strict = "*"
6565
numpy = "*"
6666
pytest = "*"
67+
hypothesis = "*"
6768
dask-core = "*" # No distributed, tornado, etc.
6869
# NOTE: don't add jax, pytorch, sparse, cupy here
6970
# as they slow down mypy and are not portable across target OSs
@@ -79,6 +80,7 @@ lint = { depends-on = ["pre-commit", "pylint", "mypy", "pyright"] }
7980
[tool.pixi.feature.tests.dependencies]
8081
pytest = ">=6"
8182
pytest-cov = ">=3"
83+
hypothesis = "*"
8284
array-api-strict = "*"
8385
numpy = "*"
8486

@@ -232,6 +234,10 @@ reportMissingTypeStubs = false
232234
reportUnreachable = false
233235
# ruff handles this
234236
reportUnusedParameter = false
237+
# cyclic imports inside function bodies
238+
reportImportCycles = false
239+
# PyRight can't trace types in lambdas
240+
reportUnknownLambdaType = false
235241

236242
executionEnvironments = [
237243
{ root = "tests", reportPrivateUsage = false },

src/array_api_extra/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from ._delegation import isclose, pad
44
from ._lib._at import at
55
from ._lib._funcs import (
6+
apply_where,
67
atleast_nd,
78
broadcast_shapes,
89
cov,
@@ -19,6 +20,7 @@
1920
# pylint: disable=duplicate-code
2021
__all__ = [
2122
"__version__",
23+
"apply_where",
2224
"at",
2325
"atleast_nd",
2426
"broadcast_shapes",

src/array_api_extra/_lib/_at.py

+18-4
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
is_jax_array,
1616
is_writeable_array,
1717
)
18+
from ._utils._helpers import meta_namespace
1819
from ._utils._typing import Array, SetIndex
1920

2021

@@ -263,6 +264,8 @@ def _op(
263264
Array
264265
Updated `x`.
265266
"""
267+
from ._funcs import apply_where # pylint: disable=cyclic-import
268+
266269
x, idx = self._x, self._idx
267270
xp = array_namespace(x, y) if xp is None else xp
268271

@@ -295,8 +298,10 @@ def _op(
295298
y_xp = xp.asarray(y, dtype=x.dtype)
296299
if y_xp.ndim == 0:
297300
if out_of_place_op: # add(), subtract(), ...
298-
# FIXME: suppress inf warnings on dask with lazywhere
299-
out = xp.where(idx, out_of_place_op(x, y_xp), x)
301+
# suppress inf warnings on Dask
302+
out = apply_where(
303+
idx, (x, y_xp), out_of_place_op, fill_value=x, xp=xp
304+
)
300305
# Undo int->float promotion on JAX after _AtOp.DIVIDE
301306
out = xp.astype(out, x.dtype, copy=False)
302307
else: # set()
@@ -420,9 +425,16 @@ def min(
420425
xp: ModuleType | None = None,
421426
) -> Array: # numpydoc ignore=PR01,RT01
422427
"""Apply ``x[idx] = minimum(x[idx], y)`` and return the updated array."""
428+
# On Dask, this function runs on the chunks, so we need to determine the
429+
# namespace that Dask is wrapping.
430+
# Note that da.minimum _incidentally_ works on numpy, cupy, and sparse
431+
# thanks to all these meta-namespaces implementing the __array_ufunc__
432+
# interface, but there's no guarantee that it will work for other
433+
# wrapped libraries in the future.
423434
xp = array_namespace(self._x) if xp is None else xp
435+
mxp = meta_namespace(self._x, xp=xp)
424436
y = xp.asarray(y)
425-
return self._op(_AtOp.MIN, xp.minimum, xp.minimum, y, copy=copy, xp=xp)
437+
return self._op(_AtOp.MIN, mxp.minimum, mxp.minimum, y, copy=copy, xp=xp)
426438

427439
def max(
428440
self,
@@ -432,6 +444,8 @@ def max(
432444
xp: ModuleType | None = None,
433445
) -> Array: # numpydoc ignore=PR01,RT01
434446
"""Apply ``x[idx] = maximum(x[idx], y)`` and return the updated array."""
447+
# See note on min()
435448
xp = array_namespace(self._x) if xp is None else xp
449+
mxp = meta_namespace(self._x, xp=xp)
436450
y = xp.asarray(y)
437-
return self._op(_AtOp.MAX, xp.maximum, xp.maximum, y, copy=copy, xp=xp)
451+
return self._op(_AtOp.MAX, mxp.maximum, mxp.maximum, y, copy=copy, xp=xp)

src/array_api_extra/_lib/_funcs.py

+157-8
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,23 @@
55

66
import math
77
import warnings
8-
from collections.abc import Sequence
8+
from collections.abc import Callable, Sequence
99
from types import ModuleType
10-
from typing import cast
10+
from typing import cast, overload
1111

1212
from ._at import at
1313
from ._utils import _compat, _helpers
14-
from ._utils._compat import array_namespace, is_jax_array
15-
from ._utils._helpers import asarrays, eager_shape, ndindex
14+
from ._utils._compat import (
15+
array_namespace,
16+
is_dask_namespace,
17+
is_jax_array,
18+
is_jax_namespace,
19+
)
20+
from ._utils._helpers import asarrays, eager_shape, meta_namespace, ndindex
1621
from ._utils._typing import Array
1722

1823
__all__ = [
24+
"apply_where",
1925
"atleast_nd",
2026
"broadcast_shapes",
2127
"cov",
@@ -29,6 +35,146 @@
2935
]
3036

3137

38+
@overload
39+
def apply_where( # type: ignore[no-any-explicit,no-any-decorated] # numpydoc ignore=GL08
40+
cond: Array,
41+
args: Array | tuple[Array, ...],
42+
f1: Callable[..., Array],
43+
f2: Callable[..., Array],
44+
/,
45+
*,
46+
xp: ModuleType | None = None,
47+
) -> Array: ...
48+
49+
50+
@overload
51+
def apply_where( # type: ignore[no-any-explicit,no-any-decorated] # numpydoc ignore=GL08
52+
cond: Array,
53+
args: Array | tuple[Array, ...],
54+
f1: Callable[..., Array],
55+
/,
56+
*,
57+
fill_value: Array | int | float | complex | bool,
58+
xp: ModuleType | None = None,
59+
) -> Array: ...
60+
61+
62+
def apply_where( # type: ignore[no-any-explicit] # numpydoc ignore=PR01,PR02
63+
cond: Array,
64+
args: Array | tuple[Array, ...],
65+
f1: Callable[..., Array],
66+
f2: Callable[..., Array] | None = None,
67+
/,
68+
*,
69+
fill_value: Array | int | float | complex | bool | None = None,
70+
xp: ModuleType | None = None,
71+
) -> Array:
72+
"""
73+
Run one of two elementwise functions depending on a condition.
74+
75+
Equivalent to ``f1(*args) if cond else fill_value`` performed elementwise
76+
when `fill_value` is defined, otherwise to ``f1(*args) if cond else f2(*args)``.
77+
78+
Parameters
79+
----------
80+
cond : array
81+
The condition, expressed as a boolean array.
82+
args : Array or tuple of Arrays
83+
Argument(s) to `f1` (and `f2`). Must be broadcastable with `cond`.
84+
f1 : callable
85+
Elementwise function of `args`, returning a single array.
86+
Where `cond` is True, output will be ``f1(arg0[cond], arg1[cond], ...)``.
87+
f2 : callable, optional
88+
Elementwise function of `args`, returning a single array.
89+
Where `cond` is False, output will be ``f2(arg0[cond], arg1[cond], ...)``.
90+
Mutually exclusive with `fill_value`.
91+
fill_value : Array or scalar, optional
92+
If provided, value with which to fill output array where `cond` is False.
93+
It does not need to be scalar; it needs however to be broadcastable with
94+
`cond` and `args`.
95+
Mutually exclusive with `f2`. You must provide one or the other.
96+
xp : array_namespace, optional
97+
The standard-compatible namespace for `cond` and `args`. Default: infer.
98+
99+
Returns
100+
-------
101+
Array
102+
An array with elements from the output of `f1` where `cond` is True and either
103+
the output of `f2` or `fill_value` where `cond` is False. The returned array has
104+
data type determined by type promotion rules between the output of `f1` and
105+
either `fill_value` or the output of `f2`.
106+
107+
Notes
108+
-----
109+
``xp.where(cond, f1(*args), f2(*args))`` requires explicitly evaluating `f1` even
110+
when `cond` is False, and `f2` when cond is True. This function evaluates each
111+
function only for their matching condition, if the backend allows for it.
112+
113+
On Dask, `f1` and `f2` are applied to the individual chunks and should use functions
114+
from the namespace of the chunks.
115+
116+
Examples
117+
--------
118+
>>> a = xp.asarray([5, 4, 3])
119+
>>> b = xp.asarray([0, 2, 2])
120+
>>> def f(a, b):
121+
... return a // b
122+
>>> apply_where(b != 0, (a, b), f, fill_value=xp.nan)
123+
array([ nan, 2., 1.])
124+
"""
125+
# Parse and normalize arguments
126+
if (f2 is None) == (fill_value is None):
127+
msg = "Exactly one of `fill_value` or `f2` must be given."
128+
raise TypeError(msg)
129+
args_ = list(args) if isinstance(args, tuple) else [args]
130+
del args
131+
132+
xp = array_namespace(cond, *args_) if xp is None else xp
133+
134+
if getattr(fill_value, "ndim", 0):
135+
cond, fill_value, *args_ = xp.broadcast_arrays(cond, fill_value, *args_)
136+
else:
137+
cond, *args_ = xp.broadcast_arrays(cond, *args_)
138+
139+
if is_dask_namespace(xp):
140+
meta_xp = meta_namespace(cond, fill_value, *args_, xp=xp)
141+
# map_blocks doesn't descend into tuples of Arrays
142+
return xp.map_blocks(_apply_where, cond, f1, f2, fill_value, *args_, xp=meta_xp)
143+
return _apply_where(cond, f1, f2, fill_value, *args_, xp=xp)
144+
145+
146+
def _apply_where( # type: ignore[no-any-explicit] # numpydoc ignore=PR01,RT01
147+
cond: Array,
148+
f1: Callable[..., Array],
149+
f2: Callable[..., Array] | None,
150+
fill_value: Array | int | float | complex | bool | None,
151+
*args: Array,
152+
xp: ModuleType,
153+
) -> Array:
154+
"""Helper of `apply_where`. On Dask, this runs on a single chunk."""
155+
156+
if is_jax_namespace(xp):
157+
# jax.jit does not support assignment by boolean mask
158+
return xp.where(cond, f1(*args), f2(*args) if f2 is not None else fill_value)
159+
160+
temp1 = f1(*(arr[cond] for arr in args))
161+
162+
if f2 is None:
163+
dtype = xp.result_type(temp1, fill_value)
164+
if getattr(fill_value, "ndim", 0):
165+
out = xp.astype(fill_value, dtype, copy=True)
166+
else:
167+
out = xp.full_like(cond, dtype=dtype, fill_value=fill_value)
168+
else:
169+
ncond = ~cond
170+
temp2 = f2(*(arr[ncond] for arr in args))
171+
dtype = xp.result_type(temp1, temp2)
172+
out = xp.empty_like(cond, dtype=dtype)
173+
out = at(out, ncond).set(temp2)
174+
175+
return at(out, cond).set(temp1)
176+
177+
32178
def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType | None = None) -> Array:
33179
"""
34180
Recursively expand the dimension of an array to at least `ndim`.
@@ -393,12 +539,15 @@ def isclose(
393539
a_inexact = xp.isdtype(a.dtype, ("real floating", "complex floating"))
394540
b_inexact = xp.isdtype(b.dtype, ("real floating", "complex floating"))
395541
if a_inexact or b_inexact:
396-
# FIXME: use scipy's lazywhere to suppress warnings on inf
397-
out = xp.where(
542+
# prevent warnings on numpy and dask on inf - inf
543+
mxp = meta_namespace(a, b, xp=xp)
544+
out = apply_where(
398545
xp.isinf(a) | xp.isinf(b),
399-
xp.isinf(a) & xp.isinf(b) & (xp.sign(a) == xp.sign(b)),
546+
(a, b),
547+
lambda a, b: mxp.isinf(a) & mxp.isinf(b) & (mxp.sign(a) == mxp.sign(b)), # pyright: ignore[reportUnknownArgumentType]
400548
# Note: inf <= inf is True!
401-
xp.abs(a - b) <= (atol + rtol * xp.abs(b)),
549+
lambda a, b: mxp.abs(a - b) <= (atol + rtol * mxp.abs(b)), # pyright: ignore[reportUnknownArgumentType]
550+
xp=xp,
402551
)
403552
if equal_nan:
404553
out = xp.where(xp.isnan(a) & xp.isnan(b), xp.asarray(True), out)

src/array_api_extra/_lib/_utils/_helpers.py

+37-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,12 @@
99
from typing import TYPE_CHECKING, cast
1010

1111
from . import _compat
12-
from ._compat import array_namespace, is_array_api_obj, is_numpy_array
12+
from ._compat import (
13+
array_namespace,
14+
is_array_api_obj,
15+
is_dask_namespace,
16+
is_numpy_array,
17+
)
1318
from ._typing import Array
1419

1520
if TYPE_CHECKING: # pragma: no cover
@@ -18,6 +23,7 @@
1823

1924

2025
__all__ = ["asarrays", "eager_shape", "in1d", "is_python_scalar", "mean"]
26+
__all__ = ["asarrays", "in1d", "is_python_scalar", "mean", "meta_namespace"]
2127

2228

2329
def in1d(
@@ -230,3 +236,33 @@ def eager_shape(x: Array, /) -> tuple[int, ...]:
230236
msg = "Unsupported lazy shape"
231237
raise TypeError(msg)
232238
return cast(tuple[int, ...], shape)
239+
240+
241+
def meta_namespace(
242+
*arrays: Array | int | float | complex | bool | None,
243+
xp: ModuleType | None = None,
244+
) -> ModuleType:
245+
"""
246+
Get the namespace of Dask chunks.
247+
248+
On all other backends, just return the namespace of the arrays.
249+
250+
Parameters
251+
----------
252+
*arrays : Array | int | float | complex | bool | None
253+
Input arrays.
254+
xp : array_namespace, optional
255+
The standard-compatible namespace for the input arrays. Default: infer.
256+
257+
Returns
258+
-------
259+
array_namespace
260+
If xp is Dask, the namespace of the Dask chunks;
261+
otherwise, the namespace of the arrays.
262+
"""
263+
xp = array_namespace(*arrays) if xp is None else xp
264+
if not is_dask_namespace(xp):
265+
return xp
266+
# Quietly skip scalars and None's
267+
metas = [cast(Array | None, getattr(a, "_meta", None)) for a in arrays]
268+
return array_namespace(*metas)

tests/test_at.py

+1-12
Original file line numberDiff line numberDiff line change
@@ -277,18 +277,7 @@ def test_bool_mask_nd(xp: ModuleType):
277277
xp_assert_equal(z, xp.asarray([[0, 2, 3], [4, 0, 0]]))
278278

279279

280-
@pytest.mark.parametrize(
281-
"bool_mask",
282-
[
283-
False,
284-
pytest.param(
285-
True,
286-
marks=pytest.mark.xfail_xp_backend(
287-
Backend.DASK, reason="FIXME need scipy's lazywhere"
288-
),
289-
),
290-
],
291-
)
280+
@pytest.mark.parametrize("bool_mask", [False, True])
292281
def test_no_inf_warnings(xp: ModuleType, bool_mask: bool):
293282
x = xp.asarray([math.inf, 1.0, 2.0])
294283
idx = ~xp.isinf(x) if bool_mask else slice(1, None)

0 commit comments

Comments
 (0)