Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 2d199dd

Browse files
committedFeb 18, 2025·
ENH: apply_where (from scipy's lazywhere)
1 parent 573ed3c commit 2d199dd

File tree

13 files changed

+443
-45
lines changed

13 files changed

+443
-45
lines changed
 

‎docs/api-reference.md

+1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
:nosignatures:
77
:toctree: generated
88
9+
apply_where
910
at
1011
atleast_nd
1112
broadcast_shapes

‎pixi.lock

+1-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

‎pyproject.toml

+5-4
Original file line numberDiff line numberDiff line change
@@ -191,10 +191,6 @@ markers = [
191191

192192
[tool.coverage]
193193
run.source = ["array_api_extra"]
194-
report.exclude_also = [
195-
'\.\.\.',
196-
'if TYPE_CHECKING:',
197-
]
198194

199195
# mypy
200196

@@ -236,6 +232,10 @@ reportMissingTypeStubs = false
236232
reportUnreachable = false
237233
# ruff handles this
238234
reportUnusedParameter = false
235+
# cyclic imports inside function bodies
236+
reportImportCycles = false
237+
# PyRight can't trace types in lambdas
238+
reportUnknownLambdaType = false
239239

240240
executionEnvironments = [
241241
{ root = "tests", reportPrivateUsage = false },
@@ -301,6 +301,7 @@ messages_control.disable = [
301301
"missing-function-docstring", # numpydoc handles this
302302
"import-error", # mypy handles this
303303
"import-outside-toplevel", # optional dependencies
304+
"cyclic-import", # cyclic imports inside function bodies
304305
]
305306

306307

‎src/array_api_extra/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from ._delegation import isclose, pad
44
from ._lib._at import at
55
from ._lib._funcs import (
6+
apply_where,
67
atleast_nd,
78
broadcast_shapes,
89
cov,
@@ -19,6 +20,7 @@
1920
# pylint: disable=duplicate-code
2021
__all__ = [
2122
"__version__",
23+
"apply_where",
2224
"at",
2325
"atleast_nd",
2426
"broadcast_shapes",

‎src/array_api_extra/_lib/_at.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -249,14 +249,16 @@ def _op(
249249
Right-hand side of the operation.
250250
copy : bool or None
251251
Whether to copy the input array. See the class docstring for details.
252-
xp : array_namespace or None
253-
The array namespace for the input array.
252+
xp : array_namespace, optional
253+
The array namespace for the input array. Default: infer.
254254
255255
Returns
256256
-------
257257
Array
258258
Updated `x`.
259259
"""
260+
from ._funcs import apply_where
261+
260262
x, idx = self._x, self._idx
261263
xp = array_namespace(x, y) if xp is None else xp
262264

@@ -294,8 +296,10 @@ def _op(
294296
y_xp = xp.asarray(y, dtype=x.dtype)
295297
if y_xp.ndim == 0:
296298
if out_of_place_op:
297-
# FIXME: suppress inf warnings on dask with lazywhere
298-
out = xp.where(idx, out_of_place_op(x, y_xp), x)
299+
# suppress inf warnings on Dask
300+
out = apply_where(
301+
idx, out_of_place_op, (x, y_xp), fill_value=x, xp=xp
302+
)
299303
# Undo int->float promotion on JAX after _AtOp.DIVIDE
300304
out = xp.astype(out, x.dtype, copy=False)
301305
else:

‎src/array_api_extra/_lib/_funcs.py

+167-8
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,23 @@
55

66
import math
77
import warnings
8-
from collections.abc import Sequence
8+
from collections.abc import Callable, Sequence
99
from types import ModuleType
10-
from typing import cast
10+
from typing import cast, overload
1111

1212
from ._at import at
1313
from ._utils import _compat, _helpers
14-
from ._utils._compat import array_namespace, is_jax_array
15-
from ._utils._helpers import asarrays
14+
from ._utils._compat import (
15+
array_namespace,
16+
is_dask_namespace,
17+
is_jax_array,
18+
is_jax_namespace,
19+
)
20+
from ._utils._helpers import asarrays, meta_namespace
1621
from ._utils._typing import Array
1722

1823
__all__ = [
24+
"apply_where",
1925
"atleast_nd",
2026
"broadcast_shapes",
2127
"cov",
@@ -29,6 +35,156 @@
2935
]
3036

3137

38+
@overload
39+
def apply_where( # type: ignore[no-any-explicit,no-any-decorated] # numpydoc ignore=GL08
40+
cond: Array,
41+
f1: Callable[..., Array],
42+
f2: Callable[..., Array],
43+
args: Array | tuple[Array, ...],
44+
/,
45+
*,
46+
xp: ModuleType | None = None,
47+
) -> Array: ...
48+
49+
50+
@overload
51+
def apply_where( # type: ignore[no-any-explicit,no-any-decorated] # numpydoc ignore=GL08
52+
cond: Array,
53+
f1: Callable[..., Array],
54+
args: Array | tuple[Array, ...],
55+
/,
56+
*,
57+
fill_value: Array | int | float | complex | bool,
58+
xp: ModuleType | None = None,
59+
) -> Array: ...
60+
61+
62+
def apply_where( # type: ignore[no-any-explicit] # numpydoc ignore=PR01,PR02
63+
cond: Array,
64+
f1: Callable[..., Array],
65+
f2: Callable[..., Array] | Array | tuple[Array], # optional positional argument
66+
args: Array | tuple[Array, ...] | None = None,
67+
/,
68+
*,
69+
fill_value: Array | int | float | complex | bool | None = None,
70+
xp: ModuleType | None = None,
71+
) -> Array:
72+
"""
73+
Run one of two elementwise functions depending on a condition.
74+
75+
Equivalent to ``f1(*args) if cond else fill_value`` performed elementwise
76+
when `fill_value` is defined, otherwise to ``f1(*args) if cond else f2(*args)``.
77+
78+
Parameters
79+
----------
80+
cond : array
81+
The condition, expressed as a boolean array.
82+
f1 : callable
83+
Elementwise function of `args`, returning a single array.
84+
Where `cond` is True, output will be ``f1(arg0[cond], arg1[cond], ...)``.
85+
f2 : callable, optional
86+
Elementwise function of `args`, returning a single array.
87+
Where `cond` is False, output will be ``f2(arg0[cond], arg1[cond], ...)``.
88+
Mutually exclusive with `fill_value`.
89+
args : Array or tuple of Arrays
90+
Argument(s) to `f1` (and `f2`). Must be broadcastable with `cond`.
91+
fill_value : Array or scalar, optional
92+
If provided, value with which to fill output array where `cond` is False.
93+
It does not need to be scalar.
94+
Mutually exclusive with `f2`. You must provide one or the other.
95+
xp : array_namespace, optional
96+
The standard-compatible namespace for `cond` and `args`. Default: infer.
97+
98+
Returns
99+
-------
100+
Array
101+
An array with elements from the output of `f1` where `cond` is True and either
102+
the output of `f2` or `fill_value` where `cond` is False. The returned array has
103+
data type determined by type promotion rules between the output of `f1` and
104+
either `fill_value` or the output of `f2`.
105+
106+
Notes
107+
-----
108+
``xp.where(cond, f1(*args), f2(*args))`` requires explicitly evaluating `f1` even
109+
when `cond` is False, and `f2` when cond is True. This function evaluates each
110+
function only for their matching condition, if the backend allows for it.
111+
112+
On Dask, f1 and f2 are applied to the individual chunks.
113+
114+
Examples
115+
--------
116+
>>> a = xp.asarray([5, 4, 3])
117+
>>> b = xp.asarray([0, 2, 2])
118+
>>> def f(a, b):
119+
... return a // b
120+
>>> apply_where(b != 0, f, (a, b), fill_value=xp.nan)
121+
array([ nan, 2., 1.])
122+
"""
123+
# Parse and normalize arguments
124+
mutually_exc_msg = "Exactly one of `fill_value` or `f2` must be given."
125+
if args is None:
126+
f2, args = None, f2
127+
if fill_value is None:
128+
raise TypeError(mutually_exc_msg)
129+
else:
130+
if not callable(f2):
131+
msg = "Third parameter must be a callable, Array, or tuple of Arrays."
132+
raise TypeError(msg)
133+
if fill_value is not None:
134+
raise TypeError(mutually_exc_msg)
135+
136+
if not isinstance(args, tuple):
137+
args = (args,)
138+
f2 = cast(Callable[..., Array] | None, f2) # type: ignore[no-any-explicit]
139+
args = cast(tuple[Array, ...], args)
140+
141+
xp = array_namespace(cond, *args) if xp is None else xp
142+
143+
if getattr(fill_value, "ndim", 0):
144+
cond, fill_value, *args = xp.broadcast_arrays(cond, fill_value, *args)
145+
else:
146+
cond, *args = xp.broadcast_arrays(cond, *args)
147+
148+
if is_dask_namespace(xp):
149+
meta_xp = meta_namespace(cond, fill_value, *args, xp=xp)
150+
return xp.map_blocks(_apply_where, cond, f1, f2, fill_value, *args, xp=meta_xp)
151+
return _apply_where(cond, f1, f2, fill_value, *args, xp=xp)
152+
153+
154+
def _apply_where( # type: ignore[no-any-explicit] # numpydoc ignore=PR01,RT01
155+
cond: Array,
156+
f1: Callable[..., Array],
157+
f2: Callable[..., Array] | None,
158+
fill_value: Array | int | float | complex | bool | None,
159+
*args: Array,
160+
xp: ModuleType,
161+
) -> Array:
162+
"""Helper of `apply_where`. On Dask, this runs on a single chunk."""
163+
164+
if is_jax_namespace(xp):
165+
# jax.jit does not support assignment by boolean mask
166+
return xp.where(cond, f1(*args), f2(*args) if f2 is not None else fill_value)
167+
168+
device = _compat.device(cond)
169+
temp1 = f1(*(arr[cond] for arr in args))
170+
171+
if f2 is None:
172+
# TODO remove asarrays once all backends support Array API 2024.12
173+
dtype = xp.result_type(*asarrays(temp1, fill_value, xp=xp))
174+
if getattr(fill_value, "ndim", 0):
175+
fill_value = xp.astype(fill_value, dtype)
176+
return at(fill_value, cond).set(temp1, copy=True)
177+
out = xp.full(cond.shape, fill_value=fill_value, dtype=dtype, device=device)
178+
else:
179+
ncond = ~cond
180+
temp2 = f2(*(arr[ncond] for arr in args))
181+
dtype = xp.result_type(temp1, temp2)
182+
out = xp.empty(cond.shape, dtype=dtype, device=device)
183+
out = at(out, ncond).set(temp2)
184+
185+
return at(out, cond).set(temp1)
186+
187+
32188
def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType | None = None) -> Array:
33189
"""
34190
Recursively expand the dimension of an array to at least `ndim`.
@@ -385,12 +541,15 @@ def isclose(
385541
a_inexact = xp.isdtype(a.dtype, ("real floating", "complex floating"))
386542
b_inexact = xp.isdtype(b.dtype, ("real floating", "complex floating"))
387543
if a_inexact or b_inexact:
388-
# FIXME: use scipy's lazywhere to suppress warnings on inf
389-
out = xp.where(
544+
# prevent warnings on numpy and dask on inf - inf
545+
mxp = meta_namespace(a, b, xp=xp)
546+
out = apply_where(
390547
xp.isinf(a) | xp.isinf(b),
391-
xp.isinf(a) & xp.isinf(b) & (xp.sign(a) == xp.sign(b)),
548+
lambda a, b: mxp.isinf(a) & mxp.isinf(b) & (mxp.sign(a) == mxp.sign(b)), # pyright: ignore[reportUnknownArgumentType]
392549
# Note: inf <= inf is True!
393-
xp.abs(a - b) <= (atol + rtol * xp.abs(b)),
550+
lambda a, b: mxp.abs(a - b) <= (atol + rtol * mxp.abs(b)), # pyright: ignore[reportUnknownArgumentType]
551+
(a, b),
552+
xp=xp,
394553
)
395554
if equal_nan:
396555
out = xp.where(xp.isnan(a) & xp.isnan(b), xp.asarray(True), out)

‎src/array_api_extra/_lib/_utils/_helpers.py

+41-6
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,15 @@
77
from typing import cast
88

99
from . import _compat
10-
from ._compat import is_array_api_obj, is_numpy_array
10+
from ._compat import (
11+
array_namespace,
12+
is_array_api_obj,
13+
is_dask_namespace,
14+
is_numpy_array,
15+
)
1116
from ._typing import Array
1217

13-
__all__ = ["in1d", "mean"]
18+
__all__ = ["asarrays", "in1d", "is_python_scalar", "mean", "meta_namespace"]
1419

1520

1621
def in1d(
@@ -33,7 +38,7 @@ def in1d(
3338
https://github.com/numpy/numpy/blob/v1.26.0/numpy/lib/arraysetops.py#L524-L758
3439
"""
3540
if xp is None:
36-
xp = _compat.array_namespace(x1, x2)
41+
xp = array_namespace(x1, x2)
3742

3843
# This code is run to make the code significantly faster
3944
if x2.shape[0] < 10 * x1.shape[0] ** 0.145:
@@ -84,7 +89,7 @@ def mean(
8489
Complex mean, https://github.com/data-apis/array-api/issues/846.
8590
"""
8691
if xp is None:
87-
xp = _compat.array_namespace(x)
92+
xp = array_namespace(x)
8893

8994
if xp.isdtype(x.dtype, "complex floating"):
9095
x_real = xp.real(x)
@@ -124,8 +129,8 @@ def asarrays(
124129
----------
125130
a, b : Array | int | float | complex | bool
126131
Input arrays or scalars. At least one must be an array.
127-
xp : ModuleType
128-
The standard-compatible namespace for the returned arrays.
132+
xp : array_namespace, optional
133+
The standard-compatible namespace for `x`. Default: infer.
129134
130135
Returns
131136
-------
@@ -175,3 +180,33 @@ def asarrays(
175180
xa, xb = xp.asarray(a), xp.asarray(b)
176181

177182
return (xb, xa) if swap else (xa, xb)
183+
184+
185+
def meta_namespace(
186+
*arrays: Array | int | float | complex | bool | None,
187+
xp: ModuleType | None = None,
188+
) -> ModuleType:
189+
"""
190+
Get the namespace of Dask chunks.
191+
192+
On all other backends, just return the namespace of the arrays.
193+
194+
Parameters
195+
----------
196+
*arrays : Array | int | float | complex | bool | None
197+
Input arrays.
198+
xp : array_namespace, optional
199+
The standard-compatible namespace for `x`. Default: infer.
200+
201+
Returns
202+
-------
203+
array_namespace
204+
If xp is Dask, the namespace of the Dask chunks;
205+
otherwise, the namespace of the arrays.
206+
"""
207+
xp = array_namespace(*arrays) if xp is None else xp
208+
if not is_dask_namespace(xp):
209+
return xp
210+
# Quietly skip scalars and None's
211+
metas = [getattr(a, "_meta", None) for a in arrays]
212+
return array_namespace(*metas)

‎src/array_api_extra/_lib/_utils/_typing.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44

55
# To be changed to a Protocol later (see data-apis/array-api#589)
66
Array = Any # type: ignore[no-any-explicit]
7+
DType = Any # type: ignore[no-any-explicit]
78
Device = Any # type: ignore[no-any-explicit]
89
Index = Any # type: ignore[no-any-explicit]
910

10-
__all__ = ["Array", "Device", "Index"]
11+
__all__ = ["Array", "DType", "Device", "Index"]

‎src/array_api_extra/testing.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
__all__ = ["lazy_xp_function", "patch_lazy_xp_functions"]
2020

21-
if TYPE_CHECKING:
21+
if TYPE_CHECKING: # pragma: no cover
2222
# TODO move ParamSpec outside TYPE_CHECKING
2323
# depends on scikit-learn abandoning Python 3.9
2424
# https://github.com/scikit-learn/scikit-learn/pull/27910#issuecomment-2568023972
@@ -169,7 +169,7 @@ def xp(request, monkeypatch):
169169
Pytest fixture, as acquired by the test itself or by one of its fixtures.
170170
monkeypatch : pytest.MonkeyPatch
171171
Pytest fixture, as acquired by the test itself or by one of its fixtures.
172-
xp : module
172+
xp : array_namespace
173173
Array namespace to be tested.
174174
175175
See Also

‎tests/conftest.py

+16-2
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,11 @@ def xp(
115115
if library == Backend.NUMPY_READONLY:
116116
return NumPyReadOnly() # type: ignore[return-value] # pyright: ignore[reportReturnType]
117117
xp = pytest.importorskip(library.value)
118+
# Possibly wrap module with array_api_compat
119+
xp = array_namespace(xp.empty(0))
118120

121+
# On Dask and JAX, monkey-patch all functions tagged by `lazy_xp_function`
122+
# in the global scope of the module containing the test function.
119123
patch_lazy_xp_functions(request, monkeypatch, xp=xp)
120124

121125
if library == Backend.JAX:
@@ -124,8 +128,18 @@ def xp(
124128
# suppress unused-ignore to run mypy in -e lint as well as -e dev
125129
jax.config.update("jax_enable_x64", True) # type: ignore[no-untyped-call,unused-ignore]
126130

127-
# Possibly wrap module with array_api_compat
128-
return array_namespace(xp.empty(0))
131+
return xp
132+
133+
134+
@pytest.fixture(params=[Backend.DASK]) # Can select the test with `pytest -k dask`
135+
def da(
136+
request: pytest.FixtureRequest, monkeypatch: pytest.MonkeyPatch
137+
) -> ModuleType: # numpydoc ignore=PR01,RT01
138+
"""Variant of the `xp` fixture that only yields dask.array."""
139+
xp = pytest.importorskip("dask.array")
140+
xp = array_namespace(xp.empty(0))
141+
patch_lazy_xp_functions(request, monkeypatch, xp=xp)
142+
return xp
129143

130144

131145
@pytest.fixture

‎tests/test_at.py

+1-12
Original file line numberDiff line numberDiff line change
@@ -241,18 +241,7 @@ def test_bool_mask_nd(xp: ModuleType):
241241
xp_assert_equal(z, xp.asarray([[0, 2, 3], [4, 0, 0]]))
242242

243243

244-
@pytest.mark.parametrize(
245-
"bool_mask",
246-
[
247-
False,
248-
pytest.param(
249-
True,
250-
marks=pytest.mark.xfail_xp_backend(
251-
Backend.DASK, reason="FIXME need scipy's lazywhere"
252-
),
253-
),
254-
],
255-
)
244+
@pytest.mark.parametrize("bool_mask", [False, True])
256245
def test_no_inf_warnings(xp: ModuleType, bool_mask: bool):
257246
x = xp.asarray([math.inf, 1.0, 2.0])
258247
idx = ~xp.isinf(x) if bool_mask else slice(1, None)

‎tests/test_funcs.py

+175-4
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
import contextlib
22
import math
33
import warnings
4+
from collections.abc import Callable
45
from types import ModuleType
56

67
import numpy as np
78
import pytest
89

910
from array_api_extra import (
11+
apply_where,
1012
at,
1113
atleast_nd,
1214
broadcast_shapes,
@@ -43,6 +45,179 @@
4345
lazy_xp_function(sinc, jax_jit=False, static_argnames="xp")
4446

4547

48+
def apply_where_jit( # type: ignore[no-any-explicit]
49+
cond: Array,
50+
f1: Callable[..., Array],
51+
f2: Callable[..., Array] | None,
52+
args: Array | tuple[Array, ...],
53+
fill_value: Array | int | float | complex | bool | None = None,
54+
xp: ModuleType | None = None,
55+
) -> Array:
56+
"""
57+
Work around jax.jit's inability to handle variadic positional arguments.
58+
59+
This is a lazy_xp_function artefact for when jax.jit is applied directly
60+
to apply_where, which would not happen in real life.
61+
"""
62+
if f2 is None:
63+
return apply_where(cond, f1, args, fill_value=fill_value, xp=xp)
64+
assert fill_value is None
65+
return apply_where(cond, f1, f2, args, xp=xp)
66+
67+
68+
lazy_xp_function(apply_where_jit, static_argnames=("f1", "f2", "xp"))
69+
70+
71+
class TestApplyWhere:
72+
@staticmethod
73+
def f1(x: Array, y: Array | int = 10) -> Array:
74+
return x + y
75+
76+
@staticmethod
77+
def f2(x: Array, y: Array | int = 10) -> Array:
78+
return x - y
79+
80+
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="read-only without .at")
81+
def test_f1_f2(self, xp: ModuleType):
82+
x = xp.asarray([1, 2, 3, 4])
83+
cond = x % 2 == 0
84+
actual = apply_where_jit(cond, self.f1, self.f2, x)
85+
expect = xp.where(cond, self.f1(x), self.f2(x))
86+
xp_assert_equal(actual, expect)
87+
88+
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="read-only without .at")
89+
def test_fill_value(self, xp: ModuleType):
90+
x = xp.asarray([1, 2, 3, 4])
91+
cond = x % 2 == 0
92+
actual = apply_where_jit(x % 2 == 0, self.f1, None, x, fill_value=0)
93+
expect = xp.where(cond, self.f1(x), xp.asarray(0))
94+
xp_assert_equal(actual, expect)
95+
96+
actual = apply_where_jit(x % 2 == 0, self.f1, None, x, fill_value=xp.asarray(0))
97+
xp_assert_equal(actual, expect)
98+
99+
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="read-only without .at")
100+
def test_args_tuple(self, xp: ModuleType):
101+
x = xp.asarray([1, 2, 3, 4])
102+
y = xp.asarray([10, 20, 30, 40])
103+
cond = x % 2 == 0
104+
actual = apply_where_jit(cond, self.f1, self.f2, (x, y))
105+
expect = xp.where(cond, self.f1(x, y), self.f2(x, y))
106+
xp_assert_equal(actual, expect)
107+
108+
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="read-only without .at")
109+
def test_broadcast(self, xp: ModuleType):
110+
x = xp.asarray([1, 2])
111+
y = xp.asarray([[10], [20], [30]])
112+
cond = xp.broadcast_to(xp.asarray(True), (4, 1, 1))
113+
114+
actual = apply_where_jit(cond, self.f1, self.f2, (x, y))
115+
expect = xp.where(cond, self.f1(x, y), self.f2(x, y))
116+
xp_assert_equal(actual, expect)
117+
118+
actual = apply_where_jit(
119+
cond,
120+
lambda x, _: x, # pyright: ignore[reportUnknownArgumentType]
121+
lambda _, y: y, # pyright: ignore[reportUnknownArgumentType]
122+
(x, y),
123+
)
124+
expect = xp.where(cond, x, y)
125+
xp_assert_equal(actual, expect)
126+
127+
# Shaped fill_value
128+
actual = apply_where_jit(cond, self.f1, None, x, fill_value=y)
129+
expect = xp.where(cond, self.f1(x), y)
130+
xp_assert_equal(actual, expect)
131+
132+
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="read-only without .at")
133+
def test_dtype_propagation(self, xp: ModuleType, library: Backend):
134+
x = xp.asarray([1, 2], dtype=xp.int8)
135+
y = xp.asarray([3, 4], dtype=xp.int16)
136+
cond = x % 2 == 0
137+
138+
mxp = np if library is Backend.DASK else xp
139+
actual = apply_where_jit(
140+
cond,
141+
self.f1,
142+
lambda x, y: mxp.astype(x - y, xp.int64), # pyright: ignore[reportUnknownArgumentType]
143+
(x, y),
144+
)
145+
assert actual.dtype == xp.int64
146+
147+
actual = apply_where_jit(cond, self.f1, None, y, fill_value=5)
148+
assert actual.dtype == xp.int16
149+
150+
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="read-only without .at")
151+
@pytest.mark.parametrize("fill_value_raw", [3, [3, 4]])
152+
@pytest.mark.parametrize(
153+
("fill_value_dtype", "expect_dtype"), [("int32", "int32"), ("int8", "int16")]
154+
)
155+
def test_dtype_propagation_fill_value(
156+
self,
157+
xp: ModuleType,
158+
fill_value_raw: int | list[int],
159+
fill_value_dtype: str,
160+
expect_dtype: str,
161+
):
162+
x = xp.asarray([1, 2], dtype=xp.int16)
163+
cond = x % 2 == 0
164+
fill_value = xp.asarray(fill_value_raw, dtype=getattr(xp, fill_value_dtype))
165+
166+
actual = apply_where_jit(cond, self.f1, None, x, fill_value=fill_value)
167+
assert actual.dtype == getattr(xp, expect_dtype)
168+
169+
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="read-only without .at")
170+
def test_dont_overwrite_fill_value(self, xp: ModuleType):
171+
x = xp.asarray([1, 2])
172+
fill_value = xp.asarray([100, 200])
173+
actual = apply_where_jit(x % 2 == 0, self.f1, None, x, fill_value=fill_value)
174+
xp_assert_equal(actual, xp.asarray([100, 12]))
175+
xp_assert_equal(fill_value, xp.asarray([100, 200]))
176+
177+
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="read-only without .at")
178+
def test_dont_run_on_false(self, xp: ModuleType):
179+
x = xp.asarray([1.0, 2.0, 0.0])
180+
y = xp.asarray([0.0, 3.0, 4.0])
181+
# On NumPy, division by zero will trigger warnings
182+
actual = apply_where_jit(
183+
x == 0,
184+
lambda x, y: x / y, # pyright: ignore[reportUnknownArgumentType]
185+
lambda x, y: y / x, # pyright: ignore[reportUnknownArgumentType]
186+
(x, y),
187+
)
188+
xp_assert_equal(actual, xp.asarray([0.0, 1.5, 0.0]))
189+
190+
def test_bad_args(self, xp: ModuleType):
191+
x = xp.asarray([1, 2, 3, 4])
192+
cond = x % 2 == 0
193+
# Neither f2 nor fill_value
194+
with pytest.raises(TypeError, match="Exactly one of"):
195+
apply_where(cond, self.f1, x) # type: ignore[call-overload] # pyright: ignore[reportCallIssue]
196+
# Both f2 and fill_value
197+
with pytest.raises(TypeError, match="Exactly one of"):
198+
apply_where(cond, self.f1, self.f2, x, fill_value=0) # type: ignore[call-overload] # pyright: ignore[reportCallIssue]
199+
# Multiple args; forgot to wrap them in a tuple
200+
with pytest.raises(TypeError, match="takes from 3 to 4 positional arguments"):
201+
apply_where(cond, self.f1, self.f2, x, x) # type: ignore[call-overload] # pyright: ignore[reportCallIssue]
202+
with pytest.raises(TypeError, match="callable"):
203+
apply_where(cond, self.f1, x, x, fill_value=0) # type: ignore[call-overload] # pyright: ignore[reportCallIssue]
204+
205+
@pytest.mark.skip_xp_backend(Backend.NUMPY_READONLY, reason="xp=xp")
206+
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="read-only without .at")
207+
def test_xp(self, xp: ModuleType):
208+
x = xp.asarray([1, 2, 3, 4])
209+
cond = x % 2 == 0
210+
actual = apply_where_jit(cond, self.f1, self.f2, x, xp=xp)
211+
expect = xp.where(cond, self.f1(x), self.f2(x))
212+
xp_assert_equal(actual, expect)
213+
214+
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="read-only without .at")
215+
def test_device(self, xp: ModuleType, device: Device):
216+
x = xp.asarray([1, 2, 3, 4], device=device)
217+
y = apply_where_jit(x % 2 == 0, self.f1, self.f2, x)
218+
assert get_device(y) == device
219+
220+
46221
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no expand_dims")
47222
class TestAtLeastND:
48223
def test_0D(self, xp: ModuleType):
@@ -334,8 +509,6 @@ def test_xp(self, xp: ModuleType):
334509

335510
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no isdtype")
336511
class TestIsClose:
337-
# FIXME use lazywhere to avoid warnings on inf
338-
@pytest.mark.filterwarnings("ignore:invalid value encountered")
339512
@pytest.mark.parametrize("swap", [False, True])
340513
@pytest.mark.parametrize(
341514
("a", "b"),
@@ -394,8 +567,6 @@ def test_broadcast(self, dtype: str, xp: ModuleType):
394567

395568
xp_assert_equal(actual, expect)
396569

397-
# FIXME use lazywhere to avoid warnings on inf
398-
@pytest.mark.filterwarnings("ignore:invalid value encountered")
399570
def test_some_inf(self, xp: ModuleType):
400571
a = xp.asarray([0.0, 1.0, xp.inf, xp.inf, xp.inf])
401572
b = xp.asarray([1e-9, 1.0, xp.inf, -xp.inf, 2.0])

‎tests/test_utils.py

+22-1
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@
55

66
from array_api_extra._lib import Backend
77
from array_api_extra._lib._testing import xp_assert_equal
8+
from array_api_extra._lib._utils._compat import array_namespace
89
from array_api_extra._lib._utils._compat import device as get_device
9-
from array_api_extra._lib._utils._helpers import asarrays, in1d
10+
from array_api_extra._lib._utils._helpers import asarrays, in1d, meta_namespace
1011
from array_api_extra._lib._utils._typing import Device
1112
from array_api_extra.testing import lazy_xp_function
1213

@@ -15,6 +16,8 @@
1516
# FIXME calls xp.unique_values without size
1617
lazy_xp_function(in1d, jax_jit=False, static_argnames=("assume_unique", "invert", "xp"))
1718

19+
np_compat = array_namespace(np.empty(0))
20+
1821

1922
class TestIn1D:
2023
@pytest.mark.xfail_xp_backend(
@@ -151,3 +154,21 @@ def test_asarrays_numpy_generics(dtype: type):
151154
xa, xb = asarrays(a, 0, xp=np)
152155
assert xa.dtype == dtype
153156
assert xb.dtype == dtype
157+
158+
159+
class TestMetaNamespace:
160+
@pytest.mark.skip_xp_backend(Backend.NUMPY_READONLY, reason="namespace tests")
161+
def test_basic(self, xp: ModuleType, library: Backend):
162+
args = None, xp.asarray(0), 1
163+
expect = np_compat if library is Backend.DASK else xp
164+
assert meta_namespace(*args) is expect
165+
166+
def test_dask_metas(self, da: ModuleType):
167+
cp = pytest.importorskip("cupy")
168+
cp_compat = array_namespace(cp.empty(0))
169+
args = None, da.from_array(cp.asarray(0)), 1
170+
assert meta_namespace(*args) is cp_compat
171+
172+
def test_xp(self, xp: ModuleType):
173+
args = None, xp.asarray(0), 1
174+
assert meta_namespace(*args, xp=xp) in (xp, np_compat)

0 commit comments

Comments
 (0)
Please sign in to comment.