Skip to content

Commit 0b7fce1

Browse files
committed
ENH: apply_where (from scipy's lazywhere)
1 parent 573ed3c commit 0b7fce1

13 files changed

+444
-45
lines changed

docs/api-reference.md

+1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
:nosignatures:
77
:toctree: generated
88
9+
apply_where
910
at
1011
atleast_nd
1112
broadcast_shapes

pixi.lock

+1-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

+5-4
Original file line numberDiff line numberDiff line change
@@ -191,10 +191,6 @@ markers = [
191191

192192
[tool.coverage]
193193
run.source = ["array_api_extra"]
194-
report.exclude_also = [
195-
'\.\.\.',
196-
'if TYPE_CHECKING:',
197-
]
198194

199195
# mypy
200196

@@ -236,6 +232,10 @@ reportMissingTypeStubs = false
236232
reportUnreachable = false
237233
# ruff handles this
238234
reportUnusedParameter = false
235+
# cyclic imports inside function bodies
236+
reportImportCycles = false
237+
# PyRight can't trace types in lambdas
238+
reportUnknownLambdaType = false
239239

240240
executionEnvironments = [
241241
{ root = "tests", reportPrivateUsage = false },
@@ -301,6 +301,7 @@ messages_control.disable = [
301301
"missing-function-docstring", # numpydoc handles this
302302
"import-error", # mypy handles this
303303
"import-outside-toplevel", # optional dependencies
304+
"cyclic-import", # cyclic imports inside function bodies
304305
]
305306

306307

src/array_api_extra/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from ._delegation import isclose, pad
44
from ._lib._at import at
55
from ._lib._funcs import (
6+
apply_where,
67
atleast_nd,
78
broadcast_shapes,
89
cov,
@@ -19,6 +20,7 @@
1920
# pylint: disable=duplicate-code
2021
__all__ = [
2122
"__version__",
23+
"apply_where",
2224
"at",
2325
"atleast_nd",
2426
"broadcast_shapes",

src/array_api_extra/_lib/_at.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -249,14 +249,16 @@ def _op(
249249
Right-hand side of the operation.
250250
copy : bool or None
251251
Whether to copy the input array. See the class docstring for details.
252-
xp : array_namespace or None
253-
The array namespace for the input array.
252+
xp : array_namespace, optional
253+
The array namespace for the input array. Default: infer.
254254
255255
Returns
256256
-------
257257
Array
258258
Updated `x`.
259259
"""
260+
from ._funcs import apply_where
261+
260262
x, idx = self._x, self._idx
261263
xp = array_namespace(x, y) if xp is None else xp
262264

@@ -294,8 +296,10 @@ def _op(
294296
y_xp = xp.asarray(y, dtype=x.dtype)
295297
if y_xp.ndim == 0:
296298
if out_of_place_op:
297-
# FIXME: suppress inf warnings on dask with lazywhere
298-
out = xp.where(idx, out_of_place_op(x, y_xp), x)
299+
# suppress inf warnings on Dask
300+
out = apply_where(
301+
idx, out_of_place_op, (x, y_xp), fill_value=x, xp=xp
302+
)
299303
# Undo int->float promotion on JAX after _AtOp.DIVIDE
300304
out = xp.astype(out, x.dtype, copy=False)
301305
else:

src/array_api_extra/_lib/_funcs.py

+168-8
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,23 @@
55

66
import math
77
import warnings
8-
from collections.abc import Sequence
8+
from collections.abc import Callable, Sequence
99
from types import ModuleType
10-
from typing import cast
10+
from typing import cast, overload
1111

1212
from ._at import at
1313
from ._utils import _compat, _helpers
14-
from ._utils._compat import array_namespace, is_jax_array
15-
from ._utils._helpers import asarrays
14+
from ._utils._compat import (
15+
array_namespace,
16+
is_dask_namespace,
17+
is_jax_array,
18+
is_jax_namespace,
19+
)
20+
from ._utils._helpers import asarrays, meta_namespace
1621
from ._utils._typing import Array
1722

1823
__all__ = [
24+
"apply_where",
1925
"atleast_nd",
2026
"broadcast_shapes",
2127
"cov",
@@ -29,6 +35,157 @@
2935
]
3036

3137

38+
@overload
39+
def apply_where( # type: ignore[no-any-explicit,no-any-decorated] # numpydoc ignore=GL08
40+
cond: Array,
41+
f1: Callable[..., Array],
42+
f2: Callable[..., Array],
43+
args: Array | tuple[Array, ...],
44+
/,
45+
*,
46+
xp: ModuleType | None = None,
47+
) -> Array: ...
48+
49+
50+
@overload
51+
def apply_where( # type: ignore[no-any-explicit,no-any-decorated] # numpydoc ignore=GL08
52+
cond: Array,
53+
f1: Callable[..., Array],
54+
args: Array | tuple[Array, ...],
55+
/,
56+
*,
57+
fill_value: Array | int | float | complex | bool,
58+
xp: ModuleType | None = None,
59+
) -> Array: ...
60+
61+
62+
def apply_where( # type: ignore[no-any-explicit] # numpydoc ignore=PR01,PR02
63+
cond: Array,
64+
f1: Callable[..., Array],
65+
f2: Callable[..., Array] | Array | tuple[Array], # optional positional argument
66+
args: Array | tuple[Array, ...] | None = None,
67+
/,
68+
*,
69+
fill_value: Array | int | float | complex | bool | None = None,
70+
xp: ModuleType | None = None,
71+
) -> Array:
72+
"""
73+
Run one of two elementwise functions depending on a condition.
74+
75+
Equivalent to ``f1(*args) if cond else fill_value`` performed elementwise
76+
when `fill_value` is defined, otherwise to ``f1(*args) if cond else f2(*args)``.
77+
78+
Parameters
79+
----------
80+
cond : array
81+
The condition, expressed as a boolean array.
82+
f1 : callable
83+
Elementwise function of `args`, returning a single array.
84+
Where `cond` is True, output will be ``f1(arg0[cond], arg1[cond], ...)``.
85+
f2 : callable, optional
86+
Elementwise function of `args`, returning a single array.
87+
Where `cond` is False, output will be ``f2(arg0[cond], arg1[cond], ...)``.
88+
Mutually exclusive with `fill_value`.
89+
args : Array or tuple of Arrays
90+
Argument(s) to `f1` (and `f2`). Must be broadcastable with `cond`.
91+
fill_value : Array or scalar, optional
92+
If provided, value with which to fill output array where `cond` is False.
93+
It does not need to be scalar.
94+
Mutually exclusive with `f2`. You must provide one or the other.
95+
xp : array_namespace, optional
96+
The standard-compatible namespace for `cond` and `args`. Default: infer.
97+
98+
Returns
99+
-------
100+
Array
101+
An array with elements from the output of `f1` where `cond` is True and either
102+
the output of `f2` or `fill_value` where `cond` is False. The returned array has
103+
data type determined by type promotion rules between the output of `f1` and
104+
either `fill_value` or the output of `f2`.
105+
106+
Notes
107+
-----
108+
``xp.where(cond, f1(*args), f2(*args))`` requires explicitly evaluating `f1` even
109+
when `cond` is False, and `f2` when cond is True. This function evaluates each
110+
function only for their matching condition, if the backend allows for it.
111+
112+
On Dask, `f1` and `f2` are applied to the individual chunks and should use functions
113+
from the namespace of the chunks.
114+
115+
Examples
116+
--------
117+
>>> a = xp.asarray([5, 4, 3])
118+
>>> b = xp.asarray([0, 2, 2])
119+
>>> def f(a, b):
120+
... return a // b
121+
>>> apply_where(b != 0, f, (a, b), fill_value=xp.nan)
122+
array([ nan, 2., 1.])
123+
"""
124+
# Parse and normalize arguments
125+
mutually_exc_msg = "Exactly one of `fill_value` or `f2` must be given."
126+
if args is None:
127+
f2, args = None, f2
128+
if fill_value is None:
129+
raise TypeError(mutually_exc_msg)
130+
else:
131+
if not callable(f2):
132+
msg = "Third parameter must be a callable, Array, or tuple of Arrays."
133+
raise TypeError(msg)
134+
if fill_value is not None:
135+
raise TypeError(mutually_exc_msg)
136+
137+
if not isinstance(args, tuple):
138+
args = (args,)
139+
f2 = cast(Callable[..., Array] | None, f2) # type: ignore[no-any-explicit]
140+
args = cast(tuple[Array, ...], args)
141+
142+
xp = array_namespace(cond, *args) if xp is None else xp
143+
144+
if getattr(fill_value, "ndim", 0):
145+
cond, fill_value, *args = xp.broadcast_arrays(cond, fill_value, *args)
146+
else:
147+
cond, *args = xp.broadcast_arrays(cond, *args)
148+
149+
if is_dask_namespace(xp):
150+
meta_xp = meta_namespace(cond, fill_value, *args, xp=xp)
151+
return xp.map_blocks(_apply_where, cond, f1, f2, fill_value, *args, xp=meta_xp)
152+
return _apply_where(cond, f1, f2, fill_value, *args, xp=xp)
153+
154+
155+
def _apply_where( # type: ignore[no-any-explicit] # numpydoc ignore=PR01,RT01
156+
cond: Array,
157+
f1: Callable[..., Array],
158+
f2: Callable[..., Array] | None,
159+
fill_value: Array | int | float | complex | bool | None,
160+
*args: Array,
161+
xp: ModuleType,
162+
) -> Array:
163+
"""Helper of `apply_where`. On Dask, this runs on a single chunk."""
164+
165+
if is_jax_namespace(xp):
166+
# jax.jit does not support assignment by boolean mask
167+
return xp.where(cond, f1(*args), f2(*args) if f2 is not None else fill_value)
168+
169+
device = _compat.device(cond)
170+
temp1 = f1(*(arr[cond] for arr in args))
171+
172+
if f2 is None:
173+
# TODO remove asarrays once all backends support Array API 2024.12
174+
dtype = xp.result_type(*asarrays(temp1, fill_value, xp=xp))
175+
if getattr(fill_value, "ndim", 0):
176+
fill_value = xp.astype(fill_value, dtype)
177+
return at(fill_value, cond).set(temp1, copy=True)
178+
out = xp.full(cond.shape, fill_value=fill_value, dtype=dtype, device=device)
179+
else:
180+
ncond = ~cond
181+
temp2 = f2(*(arr[ncond] for arr in args))
182+
dtype = xp.result_type(temp1, temp2)
183+
out = xp.empty(cond.shape, dtype=dtype, device=device)
184+
out = at(out, ncond).set(temp2)
185+
186+
return at(out, cond).set(temp1)
187+
188+
32189
def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType | None = None) -> Array:
33190
"""
34191
Recursively expand the dimension of an array to at least `ndim`.
@@ -385,12 +542,15 @@ def isclose(
385542
a_inexact = xp.isdtype(a.dtype, ("real floating", "complex floating"))
386543
b_inexact = xp.isdtype(b.dtype, ("real floating", "complex floating"))
387544
if a_inexact or b_inexact:
388-
# FIXME: use scipy's lazywhere to suppress warnings on inf
389-
out = xp.where(
545+
# prevent warnings on numpy and dask on inf - inf
546+
mxp = meta_namespace(a, b, xp=xp)
547+
out = apply_where(
390548
xp.isinf(a) | xp.isinf(b),
391-
xp.isinf(a) & xp.isinf(b) & (xp.sign(a) == xp.sign(b)),
549+
lambda a, b: mxp.isinf(a) & mxp.isinf(b) & (mxp.sign(a) == mxp.sign(b)), # pyright: ignore[reportUnknownArgumentType]
392550
# Note: inf <= inf is True!
393-
xp.abs(a - b) <= (atol + rtol * xp.abs(b)),
551+
lambda a, b: mxp.abs(a - b) <= (atol + rtol * mxp.abs(b)), # pyright: ignore[reportUnknownArgumentType]
552+
(a, b),
553+
xp=xp,
394554
)
395555
if equal_nan:
396556
out = xp.where(xp.isnan(a) & xp.isnan(b), xp.asarray(True), out)

src/array_api_extra/_lib/_utils/_helpers.py

+41-6
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,15 @@
77
from typing import cast
88

99
from . import _compat
10-
from ._compat import is_array_api_obj, is_numpy_array
10+
from ._compat import (
11+
array_namespace,
12+
is_array_api_obj,
13+
is_dask_namespace,
14+
is_numpy_array,
15+
)
1116
from ._typing import Array
1217

13-
__all__ = ["in1d", "mean"]
18+
__all__ = ["asarrays", "in1d", "is_python_scalar", "mean", "meta_namespace"]
1419

1520

1621
def in1d(
@@ -33,7 +38,7 @@ def in1d(
3338
https://github.com/numpy/numpy/blob/v1.26.0/numpy/lib/arraysetops.py#L524-L758
3439
"""
3540
if xp is None:
36-
xp = _compat.array_namespace(x1, x2)
41+
xp = array_namespace(x1, x2)
3742

3843
# This code is run to make the code significantly faster
3944
if x2.shape[0] < 10 * x1.shape[0] ** 0.145:
@@ -84,7 +89,7 @@ def mean(
8489
Complex mean, https://github.com/data-apis/array-api/issues/846.
8590
"""
8691
if xp is None:
87-
xp = _compat.array_namespace(x)
92+
xp = array_namespace(x)
8893

8994
if xp.isdtype(x.dtype, "complex floating"):
9095
x_real = xp.real(x)
@@ -124,8 +129,8 @@ def asarrays(
124129
----------
125130
a, b : Array | int | float | complex | bool
126131
Input arrays or scalars. At least one must be an array.
127-
xp : ModuleType
128-
The standard-compatible namespace for the returned arrays.
132+
xp : array_namespace, optional
133+
The standard-compatible namespace for `x`. Default: infer.
129134
130135
Returns
131136
-------
@@ -175,3 +180,33 @@ def asarrays(
175180
xa, xb = xp.asarray(a), xp.asarray(b)
176181

177182
return (xb, xa) if swap else (xa, xb)
183+
184+
185+
def meta_namespace(
186+
*arrays: Array | int | float | complex | bool | None,
187+
xp: ModuleType | None = None,
188+
) -> ModuleType:
189+
"""
190+
Get the namespace of Dask chunks.
191+
192+
On all other backends, just return the namespace of the arrays.
193+
194+
Parameters
195+
----------
196+
*arrays : Array | int | float | complex | bool | None
197+
Input arrays.
198+
xp : array_namespace, optional
199+
The standard-compatible namespace for `x`. Default: infer.
200+
201+
Returns
202+
-------
203+
array_namespace
204+
If xp is Dask, the namespace of the Dask chunks;
205+
otherwise, the namespace of the arrays.
206+
"""
207+
xp = array_namespace(*arrays) if xp is None else xp
208+
if not is_dask_namespace(xp):
209+
return xp
210+
# Quietly skip scalars and None's
211+
metas = [getattr(a, "_meta", None) for a in arrays]
212+
return array_namespace(*metas)

src/array_api_extra/_lib/_utils/_typing.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44

55
# To be changed to a Protocol later (see data-apis/array-api#589)
66
Array = Any # type: ignore[no-any-explicit]
7+
DType = Any # type: ignore[no-any-explicit]
78
Device = Any # type: ignore[no-any-explicit]
89
Index = Any # type: ignore[no-any-explicit]
910

10-
__all__ = ["Array", "Device", "Index"]
11+
__all__ = ["Array", "DType", "Device", "Index"]

0 commit comments

Comments
 (0)