|
5 | 5 |
|
6 | 6 | import math
|
7 | 7 | import warnings
|
8 |
| -from collections.abc import Sequence |
| 8 | +from collections.abc import Callable, Sequence |
| 9 | +from functools import partial |
9 | 10 | from types import ModuleType
|
10 |
| -from typing import cast |
| 11 | +from typing import cast, overload |
11 | 12 |
|
12 | 13 | from ._at import at
|
13 | 14 | from ._utils import _compat, _helpers
|
14 |
| -from ._utils._compat import array_namespace, is_jax_array |
15 |
| -from ._utils._helpers import asarrays |
16 |
| -from ._utils._typing import Array |
| 15 | +from ._utils._compat import ( |
| 16 | + array_namespace, |
| 17 | + is_array_api_obj, |
| 18 | + is_dask_namespace, |
| 19 | + is_jax_array, |
| 20 | + is_jax_namespace, |
| 21 | +) |
| 22 | +from ._utils._helpers import asarrays, get_meta |
| 23 | +from ._utils._typing import Array, DType |
17 | 24 |
|
18 | 25 | __all__ = [
|
| 26 | + "apply_where", |
19 | 27 | "atleast_nd",
|
20 | 28 | "cov",
|
21 | 29 | "create_diagonal",
|
|
28 | 36 | ]
|
29 | 37 |
|
30 | 38 |
|
| 39 | +@overload |
| 40 | +def apply_where( # type: ignore[no-any-explicit,no-any-decorated] # numpydoc ignore=GL08 |
| 41 | + cond: Array, |
| 42 | + f1: Callable[..., Array], |
| 43 | + f2: Callable[..., Array], |
| 44 | + /, |
| 45 | + *args: Array, |
| 46 | + xp: ModuleType | None = None, |
| 47 | +) -> Array: ... |
| 48 | + |
| 49 | + |
| 50 | +@overload |
| 51 | +def apply_where( # type: ignore[no-any-explicit,no-any-decorated] # numpydoc ignore=GL08 |
| 52 | + cond: Array, |
| 53 | + f1: Callable[..., Array], |
| 54 | + /, |
| 55 | + *args: Array, |
| 56 | + fill_value: Array | int | float | complex | bool, |
| 57 | + xp: ModuleType | None = None, |
| 58 | +) -> Array: ... |
| 59 | + |
| 60 | + |
| 61 | +def apply_where( # type: ignore[no-any-explicit,misc] # numpydoc ignore=PR01,PR02 |
| 62 | + cond: Array, |
| 63 | + f1: Callable[..., Array], |
| 64 | + f2: Callable[..., Array] | Array, |
| 65 | + /, |
| 66 | + *args: Array, |
| 67 | + fill_value: Array | int | float | complex | bool | None = None, |
| 68 | + xp: ModuleType | None = None, |
| 69 | +) -> Array: |
| 70 | + """ |
| 71 | + Run one of two elementwise functions depending on a condition. |
| 72 | +
|
| 73 | + Equivalent to ``f1(*args) if cond else fill_value`` performed elementwise |
| 74 | + when `fill_value` is defined, otherwise to ``f1(*args) if cond else f2(*args)``. |
| 75 | +
|
| 76 | + Parameters |
| 77 | + ---------- |
| 78 | + cond : array |
| 79 | + The condition, expressed as a boolean array. |
| 80 | + f1 : callable |
| 81 | + Where `cond` is True, output will be ``f1(arg0[cond], arg1[cond], ...)``. |
| 82 | + f2 : callable, optional |
| 83 | + Where `cond` is False, output will be ``f2(arg0[cond], arg1[cond], ...)``. |
| 84 | + Mutually exclusive with `fill_value`. |
| 85 | + *args : one or more arrays |
| 86 | + Arguments to `f1` (and `f2`). Must be broadcastable with `cond`. |
| 87 | + fill_value : Array or scalar, optional |
| 88 | + If provided, value with which to fill output array where `cond` is |
| 89 | + not True. Mutually exclusive with `f2`. You must provide one or the other. |
| 90 | + xp : array_namespace, optional |
| 91 | + The standard-compatible namespace for `cond` and `args`. Default: infer. |
| 92 | +
|
| 93 | + Returns |
| 94 | + ------- |
| 95 | + Array |
| 96 | + An array with elements from the output of `f1` where `cond` is True and either |
| 97 | + the output of `f2` or `fill_value` where `cond` is False. The returned array has |
| 98 | + data type determined by type promotion rules between the output of `f1` and |
| 99 | + either `fill_value` or the output of `f2`. |
| 100 | +
|
| 101 | + Notes |
| 102 | + ----- |
| 103 | + ``xp.where(cond, f1(*args), f2(*args))`` requires explicitly evaluating `f1` even |
| 104 | + when `cond` is False, and `f2` when cond is True. This function evaluates each |
| 105 | + function only for their matching condition, if the backend allows for it. |
| 106 | +
|
| 107 | + Examples |
| 108 | + -------- |
| 109 | + >>> a = xp.asarray([5, 4, 3]) |
| 110 | + >>> b = xp.asarray([0, 2, 2]) |
| 111 | + >>> def f(a, b): |
| 112 | + ... return a // b |
| 113 | + >>> apply_where(b != 0, f, a, b, fill_value=xp.nan) |
| 114 | + array([ nan, 2., 1.]) |
| 115 | + """ |
| 116 | + # Parse and normalize arguments |
| 117 | + mutually_exc_msg = "Exactly one of `fill_value` or `f2` must be given." |
| 118 | + if is_array_api_obj(f2): |
| 119 | + args = (cast(Array, f2), *args) |
| 120 | + if fill_value is not None: |
| 121 | + raise TypeError(mutually_exc_msg) |
| 122 | + f2_: Callable[..., Array] | None = None # type: ignore[no-any-explicit] |
| 123 | + else: |
| 124 | + if not callable(f2): |
| 125 | + msg = "Third parameter must be either an Array or callable." |
| 126 | + raise ValueError(msg) |
| 127 | + f2_ = cast(Callable[..., Array], f2) # type: ignore[no-any-explicit] |
| 128 | + if fill_value is None: |
| 129 | + raise TypeError(mutually_exc_msg) |
| 130 | + if getattr(fill_value, "ndim", 0) != 0: |
| 131 | + msg = "`fill_value` must be a scalar." |
| 132 | + raise ValueError(msg) |
| 133 | + del f2 |
| 134 | + if not args: |
| 135 | + msg = "Must give at least one input array." |
| 136 | + raise TypeError(msg) |
| 137 | + |
| 138 | + xp = array_namespace(cond, *args) if xp is None else xp |
| 139 | + |
| 140 | + # Determine output dtype |
| 141 | + metas = [get_meta(arg, xp=xp) for arg in args] |
| 142 | + temp1 = f1(*metas) |
| 143 | + if f2_ is None: |
| 144 | + if xp.__array_api_version__ >= "2024.12" or is_array_api_obj(fill_value): |
| 145 | + dtype = xp.result_type(temp1.dtype, fill_value) |
| 146 | + else: |
| 147 | + # TODO: remove this when all backends support Array API 2024.12 |
| 148 | + dtype = (xp.empty((), dtype=temp1.dtype) * fill_value).dtype |
| 149 | + else: |
| 150 | + temp2 = f2_(*metas) |
| 151 | + dtype = xp.result_type(temp1, temp2) |
| 152 | + |
| 153 | + if is_dask_namespace(xp): |
| 154 | + # Dask does not support assignment by boolean mask |
| 155 | + meta_xp = array_namespace(get_meta(cond), *metas) |
| 156 | + # pass dtype to both da.map_blocks and _apply_where |
| 157 | + return xp.map_blocks( |
| 158 | + partial(_apply_where, dtype=dtype, xp=meta_xp), |
| 159 | + cond, |
| 160 | + f1, |
| 161 | + f2_, |
| 162 | + *args, |
| 163 | + fill_value=fill_value, |
| 164 | + dtype=dtype, |
| 165 | + meta=metas[0], |
| 166 | + ) |
| 167 | + |
| 168 | + return _apply_where(cond, f1, f2_, *args, fill_value=fill_value, dtype=dtype, xp=xp) |
| 169 | + |
| 170 | + |
| 171 | +def _apply_where( # type: ignore[no-any-explicit] # numpydoc ignore=PR01,RT01 |
| 172 | + cond: Array, |
| 173 | + f1: Callable[..., Array], |
| 174 | + f2: Callable[..., Array] | None, |
| 175 | + *args: Array, |
| 176 | + fill_value: Array | int | float | complex | bool | None, |
| 177 | + dtype: DType, |
| 178 | + xp: ModuleType, |
| 179 | +) -> Array: |
| 180 | + """Helper of `apply_where`. On Dask, this runs on a single chunk.""" |
| 181 | + |
| 182 | + if is_jax_namespace(xp): |
| 183 | + # jax.jit does not support assignment by boolean mask |
| 184 | + return xp.where(cond, f1(*args), f2(*args) if f2 is not None else fill_value) |
| 185 | + |
| 186 | + device = _compat.device(cond) |
| 187 | + cond, *args = xp.broadcast_arrays(cond, *args) # pyright: ignore[reportAssignmentType] |
| 188 | + temp1 = f1(*(arr[cond] for arr in args)) |
| 189 | + |
| 190 | + if f2 is None: |
| 191 | + out = xp.full(cond.shape, fill_value=fill_value, dtype=dtype, device=device) |
| 192 | + else: |
| 193 | + ncond = ~cond |
| 194 | + temp2 = f2(*(arr[ncond] for arr in args)) |
| 195 | + out = xp.empty(cond.shape, dtype=dtype, device=device) |
| 196 | + out = at(out, ncond).set(temp2) |
| 197 | + |
| 198 | + return at(out, cond).set(temp1) |
| 199 | + |
| 200 | + |
31 | 201 | def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType | None = None) -> Array:
|
32 | 202 | """
|
33 | 203 | Recursively expand the dimension of an array to at least `ndim`.
|
|
0 commit comments