Skip to content

Commit f74d71a

Browse files
authored
ENH: apply_where (migrate lazywhere from scipy) (#141)
* ENH: apply_where (migrate _lazywhere from scipy) * Code review * merge main * tweak sparse skip
1 parent 904a7e2 commit f74d71a

File tree

10 files changed

+590
-35
lines changed

10 files changed

+590
-35
lines changed

docs/api-reference.md

+1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
:nosignatures:
77
:toctree: generated
88
9+
apply_where
910
at
1011
atleast_nd
1112
broadcast_shapes

pixi.lock

+106-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

+6
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ numpydoc = ">=1.8.0,<2"
6464
array-api-strict = "*"
6565
numpy = "*"
6666
pytest = "*"
67+
hypothesis = "*"
6768
dask-core = "*" # No distributed, tornado, etc.
6869
# NOTE: don't add jax, pytorch, sparse, cupy here
6970
# as they slow down mypy and are not portable across target OSs
@@ -79,6 +80,7 @@ lint = { depends-on = ["pre-commit", "pylint", "mypy", "pyright"] }
7980
[tool.pixi.feature.tests.dependencies]
8081
pytest = ">=6"
8182
pytest-cov = ">=3"
83+
hypothesis = "*"
8284
array-api-strict = "*"
8385
numpy = "*"
8486

@@ -231,6 +233,10 @@ reportMissingTypeStubs = false
231233
reportUnreachable = false
232234
# ruff handles this
233235
reportUnusedParameter = false
236+
# cyclic imports inside function bodies
237+
reportImportCycles = false
238+
# PyRight can't trace types in lambdas
239+
reportUnknownLambdaType = false
234240

235241
executionEnvironments = [
236242
{ root = "tests", reportPrivateUsage = false },

src/array_api_extra/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from ._delegation import isclose, pad
44
from ._lib._at import at
55
from ._lib._funcs import (
6+
apply_where,
67
atleast_nd,
78
broadcast_shapes,
89
cov,
@@ -19,6 +20,7 @@
1920
# pylint: disable=duplicate-code
2021
__all__ = [
2122
"__version__",
23+
"apply_where",
2224
"at",
2325
"atleast_nd",
2426
"broadcast_shapes",

src/array_api_extra/_lib/_at.py

+18-4
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
is_jax_array,
1616
is_writeable_array,
1717
)
18+
from ._utils._helpers import meta_namespace
1819
from ._utils._typing import Array, SetIndex
1920

2021

@@ -263,6 +264,8 @@ def _op(
263264
Array
264265
Updated `x`.
265266
"""
267+
from ._funcs import apply_where # pylint: disable=cyclic-import
268+
266269
x, idx = self._x, self._idx
267270
xp = array_namespace(x, y) if xp is None else xp
268271

@@ -295,8 +298,10 @@ def _op(
295298
y_xp = xp.asarray(y, dtype=x.dtype)
296299
if y_xp.ndim == 0:
297300
if out_of_place_op: # add(), subtract(), ...
298-
# FIXME: suppress inf warnings on dask with lazywhere
299-
out = xp.where(idx, out_of_place_op(x, y_xp), x)
301+
# suppress inf warnings on Dask
302+
out = apply_where(
303+
idx, (x, y_xp), out_of_place_op, fill_value=x, xp=xp
304+
)
300305
# Undo int->float promotion on JAX after _AtOp.DIVIDE
301306
out = xp.astype(out, x.dtype, copy=False)
302307
else: # set()
@@ -420,9 +425,16 @@ def min(
420425
xp: ModuleType | None = None,
421426
) -> Array: # numpydoc ignore=PR01,RT01
422427
"""Apply ``x[idx] = minimum(x[idx], y)`` and return the updated array."""
428+
# On Dask, this function runs on the chunks, so we need to determine the
429+
# namespace that Dask is wrapping.
430+
# Note that da.minimum _incidentally_ works on numpy, cupy, and sparse
431+
# thanks to all these meta-namespaces implementing the __array_ufunc__
432+
# interface, but there's no guarantee that it will work for other
433+
# wrapped libraries in the future.
423434
xp = array_namespace(self._x) if xp is None else xp
435+
mxp = meta_namespace(self._x, xp=xp)
424436
y = xp.asarray(y)
425-
return self._op(_AtOp.MIN, xp.minimum, xp.minimum, y, copy=copy, xp=xp)
437+
return self._op(_AtOp.MIN, mxp.minimum, mxp.minimum, y, copy=copy, xp=xp)
426438

427439
def max(
428440
self,
@@ -432,6 +444,8 @@ def max(
432444
xp: ModuleType | None = None,
433445
) -> Array: # numpydoc ignore=PR01,RT01
434446
"""Apply ``x[idx] = maximum(x[idx], y)`` and return the updated array."""
447+
# See note on min()
435448
xp = array_namespace(self._x) if xp is None else xp
449+
mxp = meta_namespace(self._x, xp=xp)
436450
y = xp.asarray(y)
437-
return self._op(_AtOp.MAX, xp.maximum, xp.maximum, y, copy=copy, xp=xp)
451+
return self._op(_AtOp.MAX, mxp.maximum, mxp.maximum, y, copy=copy, xp=xp)

src/array_api_extra/_lib/_funcs.py

+160-9
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,23 @@
55

66
import math
77
import warnings
8-
from collections.abc import Sequence
9-
from types import ModuleType
10-
from typing import cast
8+
from collections.abc import Callable, Sequence
9+
from types import ModuleType, NoneType
10+
from typing import cast, overload
1111

1212
from ._at import at
1313
from ._utils import _compat, _helpers
14-
from ._utils._compat import array_namespace, is_jax_array
15-
from ._utils._helpers import asarrays, eager_shape, ndindex
14+
from ._utils._compat import (
15+
array_namespace,
16+
is_dask_namespace,
17+
is_jax_array,
18+
is_jax_namespace,
19+
)
20+
from ._utils._helpers import asarrays, eager_shape, meta_namespace, ndindex
1621
from ._utils._typing import Array
1722

1823
__all__ = [
24+
"apply_where",
1925
"atleast_nd",
2026
"broadcast_shapes",
2127
"cov",
@@ -29,6 +35,148 @@
2935
]
3036

3137

38+
@overload
39+
def apply_where( # type: ignore[no-any-explicit,no-any-decorated] # numpydoc ignore=GL08
40+
cond: Array,
41+
args: Array | tuple[Array, ...],
42+
f1: Callable[..., Array],
43+
f2: Callable[..., Array],
44+
/,
45+
*,
46+
xp: ModuleType | None = None,
47+
) -> Array: ...
48+
49+
50+
@overload
51+
def apply_where( # type: ignore[no-any-explicit,no-any-decorated] # numpydoc ignore=GL08
52+
cond: Array,
53+
args: Array | tuple[Array, ...],
54+
f1: Callable[..., Array],
55+
/,
56+
*,
57+
fill_value: Array | complex,
58+
xp: ModuleType | None = None,
59+
) -> Array: ...
60+
61+
62+
def apply_where( # type: ignore[no-any-explicit] # numpydoc ignore=PR01,PR02
63+
cond: Array,
64+
args: Array | tuple[Array, ...],
65+
f1: Callable[..., Array],
66+
f2: Callable[..., Array] | None = None,
67+
/,
68+
*,
69+
fill_value: Array | complex | None = None,
70+
xp: ModuleType | None = None,
71+
) -> Array:
72+
"""
73+
Run one of two elementwise functions depending on a condition.
74+
75+
Equivalent to ``f1(*args) if cond else fill_value`` performed elementwise
76+
when `fill_value` is defined, otherwise to ``f1(*args) if cond else f2(*args)``.
77+
78+
Parameters
79+
----------
80+
cond : array
81+
The condition, expressed as a boolean array.
82+
args : Array or tuple of Arrays
83+
Argument(s) to `f1` (and `f2`). Must be broadcastable with `cond`.
84+
f1 : callable
85+
Elementwise function of `args`, returning a single array.
86+
Where `cond` is True, output will be ``f1(arg0[cond], arg1[cond], ...)``.
87+
f2 : callable, optional
88+
Elementwise function of `args`, returning a single array.
89+
Where `cond` is False, output will be ``f2(arg0[cond], arg1[cond], ...)``.
90+
Mutually exclusive with `fill_value`.
91+
fill_value : Array or scalar, optional
92+
If provided, value with which to fill output array where `cond` is False.
93+
It does not need to be scalar; it needs however to be broadcastable with
94+
`cond` and `args`.
95+
Mutually exclusive with `f2`. You must provide one or the other.
96+
xp : array_namespace, optional
97+
The standard-compatible namespace for `cond` and `args`. Default: infer.
98+
99+
Returns
100+
-------
101+
Array
102+
An array with elements from the output of `f1` where `cond` is True and either
103+
the output of `f2` or `fill_value` where `cond` is False. The returned array has
104+
data type determined by type promotion rules between the output of `f1` and
105+
either `fill_value` or the output of `f2`.
106+
107+
Notes
108+
-----
109+
``xp.where(cond, f1(*args), f2(*args))`` requires explicitly evaluating `f1` even
110+
when `cond` is False, and `f2` when cond is True. This function evaluates each
111+
function only for their matching condition, if the backend allows for it.
112+
113+
On Dask, `f1` and `f2` are applied to the individual chunks and should use functions
114+
from the namespace of the chunks.
115+
116+
Examples
117+
--------
118+
>>> import array_api_strict as xp
119+
>>> import array_api_extra as xpx
120+
>>> a = xp.asarray([5, 4, 3])
121+
>>> b = xp.asarray([0, 2, 2])
122+
>>> def f(a, b):
123+
... return a // b
124+
>>> xpx.apply_where(b != 0, (a, b), f, fill_value=xp.nan)
125+
array([ nan, 2., 1.])
126+
"""
127+
# Parse and normalize arguments
128+
if (f2 is None) == (fill_value is None):
129+
msg = "Exactly one of `fill_value` or `f2` must be given."
130+
raise TypeError(msg)
131+
args_ = list(args) if isinstance(args, tuple) else [args]
132+
del args
133+
134+
xp = array_namespace(cond, fill_value, *args_) if xp is None else xp
135+
136+
if isinstance(fill_value, int | float | complex | NoneType):
137+
cond, *args_ = xp.broadcast_arrays(cond, *args_)
138+
else:
139+
cond, fill_value, *args_ = xp.broadcast_arrays(cond, fill_value, *args_)
140+
141+
if is_dask_namespace(xp):
142+
meta_xp = meta_namespace(cond, fill_value, *args_, xp=xp)
143+
# map_blocks doesn't descend into tuples of Arrays
144+
return xp.map_blocks(_apply_where, cond, f1, f2, fill_value, *args_, xp=meta_xp)
145+
return _apply_where(cond, f1, f2, fill_value, *args_, xp=xp)
146+
147+
148+
def _apply_where( # type: ignore[no-any-explicit] # numpydoc ignore=PR01,RT01
149+
cond: Array,
150+
f1: Callable[..., Array],
151+
f2: Callable[..., Array] | None,
152+
fill_value: Array | int | float | complex | bool | None,
153+
*args: Array,
154+
xp: ModuleType,
155+
) -> Array:
156+
"""Helper of `apply_where`. On Dask, this runs on a single chunk."""
157+
158+
if is_jax_namespace(xp):
159+
# jax.jit does not support assignment by boolean mask
160+
return xp.where(cond, f1(*args), f2(*args) if f2 is not None else fill_value)
161+
162+
temp1 = f1(*(arr[cond] for arr in args))
163+
164+
if f2 is None:
165+
dtype = xp.result_type(temp1, fill_value)
166+
if isinstance(fill_value, int | float | complex):
167+
out = xp.full_like(cond, dtype=dtype, fill_value=fill_value)
168+
else:
169+
out = xp.astype(fill_value, dtype, copy=True)
170+
else:
171+
ncond = ~cond
172+
temp2 = f2(*(arr[ncond] for arr in args))
173+
dtype = xp.result_type(temp1, temp2)
174+
out = xp.empty_like(cond, dtype=dtype)
175+
out = at(out, ncond).set(temp2)
176+
177+
return at(out, cond).set(temp1)
178+
179+
32180
def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType | None = None) -> Array:
33181
"""
34182
Recursively expand the dimension of an array to at least `ndim`.
@@ -393,12 +541,15 @@ def isclose(
393541
a_inexact = xp.isdtype(a.dtype, ("real floating", "complex floating"))
394542
b_inexact = xp.isdtype(b.dtype, ("real floating", "complex floating"))
395543
if a_inexact or b_inexact:
396-
# FIXME: use scipy's lazywhere to suppress warnings on inf
397-
out = xp.where(
544+
# prevent warnings on numpy and dask on inf - inf
545+
mxp = meta_namespace(a, b, xp=xp)
546+
out = apply_where(
398547
xp.isinf(a) | xp.isinf(b),
399-
xp.isinf(a) & xp.isinf(b) & (xp.sign(a) == xp.sign(b)),
548+
(a, b),
549+
lambda a, b: mxp.isinf(a) & mxp.isinf(b) & (mxp.sign(a) == mxp.sign(b)), # pyright: ignore[reportUnknownArgumentType]
400550
# Note: inf <= inf is True!
401-
xp.abs(a - b) <= (atol + rtol * xp.abs(b)),
551+
lambda a, b: mxp.abs(a - b) <= (atol + rtol * mxp.abs(b)), # pyright: ignore[reportUnknownArgumentType]
552+
xp=xp,
402553
)
403554
if equal_nan:
404555
out = xp.where(xp.isnan(a) & xp.isnan(b), xp.asarray(True), out)

src/array_api_extra/_lib/_utils/_helpers.py

+44-2
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,27 @@
99
from typing import TYPE_CHECKING, cast
1010

1111
from . import _compat
12-
from ._compat import array_namespace, is_array_api_obj, is_numpy_array
12+
from ._compat import (
13+
array_namespace,
14+
is_array_api_obj,
15+
is_dask_namespace,
16+
is_numpy_array,
17+
)
1318
from ._typing import Array
1419

1520
if TYPE_CHECKING: # pragma: no cover
1621
# TODO import from typing (requires Python >=3.13)
1722
from typing_extensions import TypeIs
1823

1924

20-
__all__ = ["asarrays", "eager_shape", "in1d", "is_python_scalar", "mean"]
25+
__all__ = [
26+
"asarrays",
27+
"eager_shape",
28+
"in1d",
29+
"is_python_scalar",
30+
"mean",
31+
"meta_namespace",
32+
]
2133

2234

2335
def in1d(
@@ -230,3 +242,33 @@ def eager_shape(x: Array, /) -> tuple[int, ...]:
230242
msg = "Unsupported lazy shape"
231243
raise TypeError(msg)
232244
return cast(tuple[int, ...], shape)
245+
246+
247+
def meta_namespace(
248+
*arrays: Array | int | float | complex | bool | None,
249+
xp: ModuleType | None = None,
250+
) -> ModuleType:
251+
"""
252+
Get the namespace of Dask chunks.
253+
254+
On all other backends, just return the namespace of the arrays.
255+
256+
Parameters
257+
----------
258+
*arrays : Array | int | float | complex | bool | None
259+
Input arrays.
260+
xp : array_namespace, optional
261+
The standard-compatible namespace for the input arrays. Default: infer.
262+
263+
Returns
264+
-------
265+
array_namespace
266+
If xp is Dask, the namespace of the Dask chunks;
267+
otherwise, the namespace of the arrays.
268+
"""
269+
xp = array_namespace(*arrays) if xp is None else xp
270+
if not is_dask_namespace(xp):
271+
return xp
272+
# Quietly skip scalars and None's
273+
metas = [cast(Array | None, getattr(a, "_meta", None)) for a in arrays]
274+
return array_namespace(*metas)

tests/test_at.py

+1-12
Original file line numberDiff line numberDiff line change
@@ -277,18 +277,7 @@ def test_bool_mask_nd(xp: ModuleType):
277277
xp_assert_equal(z, xp.asarray([[0, 2, 3], [4, 0, 0]]))
278278

279279

280-
@pytest.mark.parametrize(
281-
"bool_mask",
282-
[
283-
False,
284-
pytest.param(
285-
True,
286-
marks=pytest.mark.xfail_xp_backend(
287-
Backend.DASK, reason="FIXME need scipy's lazywhere"
288-
),
289-
),
290-
],
291-
)
280+
@pytest.mark.parametrize("bool_mask", [False, True])
292281
def test_no_inf_warnings(xp: ModuleType, bool_mask: bool):
293282
x = xp.asarray([math.inf, 1.0, 2.0])
294283
idx = ~xp.isinf(x) if bool_mask else slice(1, None)

0 commit comments

Comments
 (0)