Skip to content

Commit 41e26ad

Browse files
committed
WIP lazywhere
1 parent 27b0bf2 commit 41e26ad

File tree

5 files changed

+206
-9
lines changed

5 files changed

+206
-9
lines changed

docs/api-reference.md

+1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
:nosignatures:
77
:toctree: generated
88
9+
apply_where
910
at
1011
atleast_nd
1112
cov

src/array_api_extra/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from ._delegation import isclose, pad
44
from ._lib._at import at
55
from ._lib._funcs import (
6+
apply_where,
67
atleast_nd,
78
cov,
89
create_diagonal,
@@ -18,6 +19,7 @@
1819
# pylint: disable=duplicate-code
1920
__all__ = [
2021
"__version__",
22+
"apply_where",
2123
"at",
2224
"atleast_nd",
2325
"cov",

src/array_api_extra/_lib/_funcs.py

+175-5
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,25 @@
55

66
import math
77
import warnings
8-
from collections.abc import Sequence
8+
from collections.abc import Callable, Sequence
9+
from functools import partial
910
from types import ModuleType
10-
from typing import cast
11+
from typing import cast, overload
1112

1213
from ._at import at
1314
from ._utils import _compat, _helpers
14-
from ._utils._compat import array_namespace, is_jax_array
15-
from ._utils._helpers import asarrays
16-
from ._utils._typing import Array
15+
from ._utils._compat import (
16+
array_namespace,
17+
is_array_api_obj,
18+
is_dask_namespace,
19+
is_jax_array,
20+
is_jax_namespace,
21+
)
22+
from ._utils._helpers import asarrays, get_meta
23+
from ._utils._typing import Array, DType
1724

1825
__all__ = [
26+
"apply_where",
1927
"atleast_nd",
2028
"cov",
2129
"create_diagonal",
@@ -28,6 +36,168 @@
2836
]
2937

3038

39+
@overload
def apply_where(  # type: ignore[no-any-explicit,no-any-decorated] # numpydoc ignore=GL08
    cond: Array,
    f1: Callable[..., Array],
    f2: Callable[..., Array],
    /,
    *args: Array,
    xp: ModuleType | None = None,
) -> Array: ...


@overload
def apply_where(  # type: ignore[no-any-explicit,no-any-decorated] # numpydoc ignore=GL08
    cond: Array,
    f1: Callable[..., Array],
    /,
    *args: Array,
    fill_value: Array | int | float | complex | bool,
    xp: ModuleType | None = None,
) -> Array: ...


def apply_where(  # type: ignore[no-any-explicit,misc] # numpydoc ignore=PR01,PR02
    cond: Array,
    f1: Callable[..., Array],
    f2: Callable[..., Array] | Array,
    /,
    *args: Array,
    fill_value: Array | int | float | complex | bool | None = None,
    xp: ModuleType | None = None,
) -> Array:
    """
    Run one of two elementwise functions depending on a condition.

    Equivalent to ``f1(*args) if cond else fill_value`` performed elementwise
    when `fill_value` is defined, otherwise to ``f1(*args) if cond else f2(*args)``.

    Parameters
    ----------
    cond : array
        The condition, expressed as a boolean array.
    f1 : callable
        Where `cond` is True, output will be ``f1(arg0[cond], arg1[cond], ...)``.
    f2 : callable, optional
        Where `cond` is False, output will be ``f2(arg0[~cond], arg1[~cond], ...)``.
        Mutually exclusive with `fill_value`. When omitted, the third positional
        argument is interpreted as the first of `args`.
    *args : one or more arrays
        Arguments to `f1` (and `f2`). Must be broadcastable with `cond`.
    fill_value : Array or scalar, optional
        If provided, value with which to fill output array where `cond` is
        not True. Must be a Python scalar or a 0-dimensional array.
        Mutually exclusive with `f2`. You must provide one or the other.
    xp : array_namespace, optional
        The standard-compatible namespace for `cond` and `args`. Default: infer.

    Returns
    -------
    Array
        An array with elements from the output of `f1` where `cond` is True and either
        the output of `f2` or `fill_value` where `cond` is False. The returned array has
        data type determined by type promotion rules between the output of `f1` and
        either `fill_value` or the output of `f2`.

    Raises
    ------
    TypeError
        If both or neither of `f2` and `fill_value` are given, or if no input
        arrays are given.
    ValueError
        If the third positional argument is neither an array nor a callable,
        or if `fill_value` is not a scalar.

    Notes
    -----
    ``xp.where(cond, f1(*args), f2(*args))`` requires explicitly evaluating `f1` even
    when `cond` is False, and `f2` when cond is True. This function evaluates each
    function only for their matching condition, if the backend allows for it.

    Examples
    --------
    >>> a = xp.asarray([5, 4, 3])
    >>> b = xp.asarray([0, 2, 2])
    >>> def f(a, b):
    ...     return a // b
    >>> apply_where(b != 0, f, a, b, fill_value=xp.nan)
    array([ nan,  2.,  1.])
    """
    # Parse and normalize arguments
    mutually_exc_msg = "Exactly one of `fill_value` or `f2` must be given."
    f2_: Callable[..., Array] | None  # type: ignore[no-any-explicit]
    if is_array_api_obj(f2):
        # No `f2` callable was passed: the third positional argument is actually
        # the first input array, so `fill_value` becomes mandatory.
        args = (cast(Array, f2), *args)
        f2_ = None
        if fill_value is None:
            raise TypeError(mutually_exc_msg)
        if getattr(fill_value, "ndim", 0) != 0:
            msg = "`fill_value` must be a scalar."
            raise ValueError(msg)
    else:
        if not callable(f2):
            msg = "Third parameter must be either an Array or callable."
            raise ValueError(msg)
        f2_ = cast(Callable[..., Array], f2)  # type: ignore[no-any-explicit]
        # `f2` already covers the False branch; a fill value would be ambiguous.
        if fill_value is not None:
            raise TypeError(mutually_exc_msg)
    del f2
    if not args:
        msg = "Must give at least one input array."
        raise TypeError(msg)

    xp = array_namespace(cond, *args) if xp is None else xp

    # Determine output dtype by running the functions on 0-sized dummy inputs,
    # so that no real data is computed just to discover the result type.
    metas = [get_meta(arg, xp=xp) for arg in args]
    temp1 = f1(*metas)
    if f2_ is None:
        # Version strings are formatted as "YYYY.MM", so a lexicographic
        # comparison orders them chronologically.
        if xp.__array_api_version__ >= "2024.12" or is_array_api_obj(fill_value):
            dtype = xp.result_type(temp1.dtype, fill_value)
        else:
            # TODO: remove this when all backends support Array API 2024.12
            dtype = (xp.empty((), dtype=temp1.dtype) * fill_value).dtype
    else:
        temp2 = f2_(*metas)
        dtype = xp.result_type(temp1, temp2)

    if is_dask_namespace(xp):
        # Dask does not support assignment by boolean mask
        meta_xp = array_namespace(get_meta(cond), *metas)
        # pass dtype to both da.map_blocks and _apply_where
        return xp.map_blocks(
            partial(_apply_where, dtype=dtype, xp=meta_xp),
            cond,
            f1,
            f2_,
            *args,
            fill_value=fill_value,
            dtype=dtype,
            meta=metas[0],
        )

    return _apply_where(cond, f1, f2_, *args, fill_value=fill_value, dtype=dtype, xp=xp)
169+
170+
171+
def _apply_where(  # type: ignore[no-any-explicit] # numpydoc ignore=PR01,RT01
    cond: Array,
    f1: Callable[..., Array],
    f2: Callable[..., Array] | None,
    *args: Array,
    fill_value: Array | int | float | complex | bool | None,
    dtype: DType,
    xp: ModuleType,
) -> Array:
    """
    Eager worker behind `apply_where`; on Dask it is invoked once per chunk.

    `f1` is applied to the elements selected by `cond`. The remaining elements
    come from `f2` when it is not None, otherwise from `fill_value`.
    """
    if is_jax_namespace(xp):
        # jax.jit does not support assignment by boolean mask, so both
        # branches are evaluated on the full inputs and merged with `where`.
        on_true = f1(*args)
        on_false = fill_value if f2 is None else f2(*args)
        return xp.where(cond, on_true, on_false)

    device = _compat.device(cond)
    mask, *operands = xp.broadcast_arrays(cond, *args)
    on_true = f1(*[operand[mask] for operand in operands])

    if f2 is not None:
        # Evaluate `f2` only on the elements where the condition is False.
        inverse = ~mask
        on_false = f2(*[operand[inverse] for operand in operands])
        result = xp.empty(mask.shape, dtype=dtype, device=device)
        result = at(result, inverse).set(on_false)
    else:
        # Pre-fill everything; the True elements are overwritten below.
        result = xp.full(mask.shape, fill_value=fill_value, dtype=dtype, device=device)

    return at(result, mask).set(on_true)
199+
200+
31201
def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType | None = None) -> Array:
32202
"""
33203
Recursively expand the dimension of an array to at least `ndim`.

src/array_api_extra/_lib/_utils/_helpers.py

+26-3
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,9 @@
77
from typing import cast
88

99
from . import _compat
10-
from ._compat import is_array_api_obj, is_numpy_array
10+
from ._compat import array_namespace, is_array_api_obj, is_dask_array, is_numpy_array
1111
from ._typing import Array
1212

13-
__all__ = ["in1d", "mean"]
14-
1513

1614
def in1d(
1715
x1: Array,
@@ -175,3 +173,28 @@ def asarrays(
175173
xa, xb = xp.asarray(a), xp.asarray(b)
176174

177175
return (xb, xa) if swap else (xa, xb)
176+
177+
178+
def get_meta(x: Array, xp: ModuleType | None = None) -> Array:
    """
    Build an empty stand-in array that mimics `x`.

    Parameters
    ----------
    x : Array
        The array to mock.
    xp : ModuleType, optional
        The array namespace to use. If None, it is inferred from `x`.

    Returns
    -------
    Array
        A size-0 array created in the namespace `xp`, with the same
        dimensionality, dtype and device as `x`.
        For Dask arrays, the ``_meta`` attribute of `x` is returned instead
        and `xp` is ignored.
    """
    if not is_dask_array(x):
        if xp is None:
            xp = array_namespace(x)
        # Every axis has length 0, so no data is allocated.
        return xp.empty((0,) * x.ndim, dtype=x.dtype, device=_compat.device(x))
    return x._meta  # pylint: disable=protected-access

src/array_api_extra/_lib/_utils/_typing.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
# To be changed to a Protocol later (see data-apis/array-api#589)
66
Array = Any # type: ignore[no-any-explicit]
77
Device = Any # type: ignore[no-any-explicit]
8+
DType = Any # type: ignore[no-any-explicit]
89
Index = Any # type: ignore[no-any-explicit]
910

10-
__all__ = ["Array", "Device", "Index"]
11+
__all__ = ["Array", "DType", "Device", "Index"]

0 commit comments

Comments
 (0)