Skip to content

Commit ad8c777

Browse files
committed
Move args to second position
1 parent 43dc1f3 commit ad8c777

File tree

3 files changed

+39
-74
lines changed

3 files changed

+39
-74
lines changed

Diff for: src/array_api_extra/_lib/_at.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -299,7 +299,7 @@ def _op(
299299
if out_of_place_op:
300300
# suppress inf warnings on Dask
301301
out = apply_where(
302-
idx, out_of_place_op, (x, y_xp), fill_value=x, xp=xp
302+
idx, (x, y_xp), out_of_place_op, fill_value=x, xp=xp
303303
)
304304
# Undo int->float promotion on JAX after _AtOp.DIVIDE
305305
out = xp.astype(out, x.dtype, copy=False)

Diff for: src/array_api_extra/_lib/_funcs.py

+11-21
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,9 @@
3838
@overload
3939
def apply_where( # type: ignore[no-any-explicit,no-any-decorated] # numpydoc ignore=GL08
4040
cond: Array,
41+
args: Array | tuple[Array, ...],
4142
f1: Callable[..., Array],
4243
f2: Callable[..., Array],
43-
args: Array | tuple[Array, ...],
4444
/,
4545
*,
4646
xp: ModuleType | None = None,
@@ -50,8 +50,8 @@ def apply_where( # type: ignore[no-any-explicit,no-any-decorated] # numpydoc ig
5050
@overload
5151
def apply_where( # type: ignore[no-any-explicit,no-any-decorated] # numpydoc ignore=GL08
5252
cond: Array,
53-
f1: Callable[..., Array],
5453
args: Array | tuple[Array, ...],
54+
f1: Callable[..., Array],
5555
/,
5656
*,
5757
fill_value: Array | int | float | complex | bool,
@@ -61,9 +61,9 @@ def apply_where( # type: ignore[no-any-explicit,no-any-decorated] # numpydoc ig
6161

6262
def apply_where( # type: ignore[no-any-explicit] # numpydoc ignore=PR01,PR02
6363
cond: Array,
64+
args: Array | tuple[Array, ...],
6465
f1: Callable[..., Array],
65-
f2: Callable[..., Array] | Array | tuple[Array], # optional positional argument
66-
args: Array | tuple[Array, ...] | None = None,
66+
f2: Callable[..., Array] | None = None,
6767
/,
6868
*,
6969
fill_value: Array | int | float | complex | bool | None = None,
@@ -79,15 +79,15 @@ def apply_where( # type: ignore[no-any-explicit] # numpydoc ignore=PR01,PR02
7979
----------
8080
cond : array
8181
The condition, expressed as a boolean array.
82+
args : Array or tuple of Arrays
83+
Argument(s) to `f1` (and `f2`). Must be broadcastable with `cond`.
8284
f1 : callable
8385
Elementwise function of `args`, returning a single array.
8486
Where `cond` is True, output will be ``f1(arg0[cond], arg1[cond], ...)``.
8587
f2 : callable, optional
8688
Elementwise function of `args`, returning a single array.
8789
Where `cond` is False, output will be ``f2(arg0[cond], arg1[cond], ...)``.
8890
Mutually exclusive with `fill_value`.
89-
args : Array or tuple of Arrays
90-
Argument(s) to `f1` (and `f2`). Must be broadcastable with `cond`.
9191
fill_value : Array or scalar, optional
9292
If provided, value with which to fill output array where `cond` is False.
9393
It does not need to be scalar; it needs however to be broadcastable with
@@ -119,25 +119,15 @@ def apply_where( # type: ignore[no-any-explicit] # numpydoc ignore=PR01,PR02
119119
>>> b = xp.asarray([0, 2, 2])
120120
>>> def f(a, b):
121121
... return a // b
122-
>>> apply_where(b != 0, f, (a, b), fill_value=xp.nan)
122+
>>> apply_where(b != 0, (a, b), f, fill_value=xp.nan)
123123
array([ nan, 2., 1.])
124124
"""
125125
# Parse and normalize arguments
126-
mutually_exc_msg = "Exactly one of `fill_value` or `f2` must be given."
127-
if args is None:
128-
f2, args = None, f2
129-
if fill_value is None:
130-
raise TypeError(mutually_exc_msg)
131-
else:
132-
if not callable(f2):
133-
msg = "Third parameter must be a callable, Array, or tuple of Arrays."
134-
raise TypeError(msg)
135-
if fill_value is not None:
136-
raise TypeError(mutually_exc_msg)
137-
126+
if (f2 is None) == (fill_value is None):
127+
msg = "Exactly one of `fill_value` or `f2` must be given."
128+
raise TypeError(msg)
138129
if not isinstance(args, tuple):
139130
args = (args,)
140-
f2 = cast(Callable[..., Array] | None, f2) # type: ignore[no-any-explicit]
141131
args = cast(tuple[Array, ...], args)
142132

143133
xp = array_namespace(cond, *args) if xp is None else xp
@@ -547,10 +537,10 @@ def isclose(
547537
mxp = meta_namespace(a, b, xp=xp)
548538
out = apply_where(
549539
xp.isinf(a) | xp.isinf(b),
540+
(a, b),
550541
lambda a, b: mxp.isinf(a) & mxp.isinf(b) & (mxp.sign(a) == mxp.sign(b)), # pyright: ignore[reportUnknownArgumentType]
551542
# Note: inf <= inf is True!
552543
lambda a, b: mxp.abs(a - b) <= (atol + rtol * mxp.abs(b)), # pyright: ignore[reportUnknownArgumentType]
553-
(a, b),
554544
xp=xp,
555545
)
556546
if equal_nan:

Diff for: tests/test_funcs.py

+27-52
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import contextlib
22
import math
33
import warnings
4-
from collections.abc import Callable
54
from types import ModuleType
65

76
import hypothesis
@@ -36,6 +35,7 @@
3635
# some xp backends are untyped
3736
# mypy: disable-error-code=no-untyped-def
3837

38+
lazy_xp_function(apply_where, static_argnums=(2, 3), static_argnames="xp")
3939
lazy_xp_function(atleast_nd, static_argnames=("ndim", "xp"))
4040
lazy_xp_function(cov, static_argnames="xp")
4141
# FIXME .device attribute https://github.com/data-apis/array-api-compat/pull/238
@@ -50,29 +50,6 @@
5050
lazy_xp_function(sinc, jax_jit=False, static_argnames="xp")
5151

5252

53-
def apply_where_jit( # type: ignore[no-any-explicit]
54-
cond: Array,
55-
f1: Callable[..., Array],
56-
f2: Callable[..., Array] | None,
57-
args: Array | tuple[Array, ...],
58-
fill_value: Array | int | float | complex | bool | None = None,
59-
xp: ModuleType | None = None,
60-
) -> Array:
61-
"""
62-
Work around jax.jit's inability to handle variadic positional arguments.
63-
64-
This is a lazy_xp_function artefact for when jax.jit is applied directly
65-
to apply_where, which would not happen in real life.
66-
"""
67-
if f2 is None:
68-
return apply_where(cond, f1, args, fill_value=fill_value, xp=xp)
69-
assert fill_value is None
70-
return apply_where(cond, f1, f2, args, xp=xp)
71-
72-
73-
lazy_xp_function(apply_where_jit, static_argnames=("f1", "f2", "xp"))
74-
75-
7653
class TestApplyWhere:
7754
@staticmethod
7855
def f1(x: Array, y: Array | int = 10) -> Array:
@@ -86,27 +63,27 @@ def f2(x: Array, y: Array | int = 10) -> Array:
8663
def test_f1_f2(self, xp: ModuleType):
8764
x = xp.asarray([1, 2, 3, 4])
8865
cond = x % 2 == 0
89-
actual = apply_where_jit(cond, self.f1, self.f2, x)
66+
actual = apply_where(cond, x, self.f1, self.f2)
9067
expect = xp.where(cond, self.f1(x), self.f2(x))
9168
xp_assert_equal(actual, expect)
9269

9370
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="read-only without .at")
9471
def test_fill_value(self, xp: ModuleType):
9572
x = xp.asarray([1, 2, 3, 4])
9673
cond = x % 2 == 0
97-
actual = apply_where_jit(x % 2 == 0, self.f1, None, x, fill_value=0)
74+
actual = apply_where(x % 2 == 0, x, self.f1, fill_value=0)
9875
expect = xp.where(cond, self.f1(x), xp.asarray(0))
9976
xp_assert_equal(actual, expect)
10077

101-
actual = apply_where_jit(x % 2 == 0, self.f1, None, x, fill_value=xp.asarray(0))
78+
actual = apply_where(x % 2 == 0, x, self.f1, fill_value=xp.asarray(0))
10279
xp_assert_equal(actual, expect)
10380

10481
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="read-only without .at")
10582
def test_args_tuple(self, xp: ModuleType):
10683
x = xp.asarray([1, 2, 3, 4])
10784
y = xp.asarray([10, 20, 30, 40])
10885
cond = x % 2 == 0
109-
actual = apply_where_jit(cond, self.f1, self.f2, (x, y))
86+
actual = apply_where(cond, (x, y), self.f1, self.f2)
11087
expect = xp.where(cond, self.f1(x, y), self.f2(x, y))
11188
xp_assert_equal(actual, expect)
11289

@@ -116,21 +93,21 @@ def test_broadcast(self, xp: ModuleType):
11693
y = xp.asarray([[10], [20], [30]])
11794
cond = xp.broadcast_to(xp.asarray(True), (4, 1, 1))
11895

119-
actual = apply_where_jit(cond, self.f1, self.f2, (x, y))
96+
actual = apply_where(cond, (x, y), self.f1, self.f2)
12097
expect = xp.where(cond, self.f1(x, y), self.f2(x, y))
12198
xp_assert_equal(actual, expect)
12299

123-
actual = apply_where_jit(
100+
actual = apply_where(
124101
cond,
102+
(x, y),
125103
lambda x, _: x, # pyright: ignore[reportUnknownArgumentType]
126104
lambda _, y: y, # pyright: ignore[reportUnknownArgumentType]
127-
(x, y),
128105
)
129106
expect = xp.where(cond, x, y)
130107
xp_assert_equal(actual, expect)
131108

132109
# Shaped fill_value
133-
actual = apply_where_jit(cond, self.f1, None, x, fill_value=y)
110+
actual = apply_where(cond, x, self.f1, fill_value=y)
134111
expect = xp.where(cond, self.f1(x), y)
135112
xp_assert_equal(actual, expect)
136113

@@ -141,15 +118,15 @@ def test_dtype_propagation(self, xp: ModuleType, library: Backend):
141118
cond = x % 2 == 0
142119

143120
mxp = np if library is Backend.DASK else xp
144-
actual = apply_where_jit(
121+
actual = apply_where(
145122
cond,
123+
(x, y),
146124
self.f1,
147125
lambda x, y: mxp.astype(x - y, xp.int64), # pyright: ignore[reportUnknownArgumentType]
148-
(x, y),
149126
)
150127
assert actual.dtype == xp.int64
151128

152-
actual = apply_where_jit(cond, self.f1, None, y, fill_value=5)
129+
actual = apply_where(cond, y, self.f1, fill_value=5)
153130
assert actual.dtype == xp.int16
154131

155132
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="read-only without .at")
@@ -168,14 +145,14 @@ def test_dtype_propagation_fill_value(
168145
cond = x % 2 == 0
169146
fill_value = xp.asarray(fill_value_raw, dtype=getattr(xp, fill_value_dtype))
170147

171-
actual = apply_where_jit(cond, self.f1, None, x, fill_value=fill_value)
148+
actual = apply_where(cond, x, self.f1, fill_value=fill_value)
172149
assert actual.dtype == getattr(xp, expect_dtype)
173150

174151
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="read-only without .at")
175152
def test_dont_overwrite_fill_value(self, xp: ModuleType):
176153
x = xp.asarray([1, 2])
177154
fill_value = xp.asarray([100, 200])
178-
actual = apply_where_jit(x % 2 == 0, self.f1, None, x, fill_value=fill_value)
155+
actual = apply_where(x % 2 == 0, x, self.f1, fill_value=fill_value)
179156
xp_assert_equal(actual, xp.asarray([100, 12]))
180157
xp_assert_equal(fill_value, xp.asarray([100, 200]))
181158

@@ -184,11 +161,11 @@ def test_dont_run_on_false(self, xp: ModuleType):
184161
x = xp.asarray([1.0, 2.0, 0.0])
185162
y = xp.asarray([0.0, 3.0, 4.0])
186163
# On NumPy, division by zero will trigger warnings
187-
actual = apply_where_jit(
164+
actual = apply_where(
188165
x == 0,
166+
(x, y),
189167
lambda x, y: x / y, # pyright: ignore[reportUnknownArgumentType]
190168
lambda x, y: y / x, # pyright: ignore[reportUnknownArgumentType]
191-
(x, y),
192169
)
193170
xp_assert_equal(actual, xp.asarray([0.0, 1.5, 0.0]))
194171

@@ -197,29 +174,28 @@ def test_bad_args(self, xp: ModuleType):
197174
cond = x % 2 == 0
198175
# Neither f2 nor fill_value
199176
with pytest.raises(TypeError, match="Exactly one of"):
200-
apply_where(cond, self.f1, x) # type: ignore[call-overload] # pyright: ignore[reportCallIssue]
177+
apply_where(cond, x, self.f1) # type: ignore[call-overload] # pyright: ignore[reportCallIssue]
201178
# Both f2 and fill_value
202179
with pytest.raises(TypeError, match="Exactly one of"):
203-
apply_where(cond, self.f1, self.f2, x, fill_value=0) # type: ignore[call-overload] # pyright: ignore[reportCallIssue]
204-
# Multiple args; forgot to wrap them in a tuple
205-
with pytest.raises(TypeError, match="takes from 3 to 4 positional arguments"):
206-
apply_where(cond, self.f1, self.f2, x, x) # type: ignore[call-overload] # pyright: ignore[reportCallIssue]
207-
with pytest.raises(TypeError, match="callable"):
208-
apply_where(cond, self.f1, x, x, fill_value=0) # type: ignore[call-overload] # pyright: ignore[reportCallIssue]
180+
apply_where(cond, x, self.f1, self.f2, fill_value=0) # type: ignore[call-overload] # pyright: ignore[reportCallIssue]
209181

210182
@pytest.mark.skip_xp_backend(Backend.NUMPY_READONLY, reason="xp=xp")
211183
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="read-only without .at")
212184
def test_xp(self, xp: ModuleType):
213185
x = xp.asarray([1, 2, 3, 4])
214186
cond = x % 2 == 0
215-
actual = apply_where_jit(cond, self.f1, self.f2, x, xp=xp)
187+
actual = apply_where(cond, x, self.f1, self.f2, xp=xp)
216188
expect = xp.where(cond, self.f1(x), self.f2(x))
217189
xp_assert_equal(actual, expect)
218190

219191
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="read-only without .at")
220192
def test_device(self, xp: ModuleType, device: Device):
221193
x = xp.asarray([1, 2, 3, 4], device=device)
222-
y = apply_where_jit(x % 2 == 0, self.f1, self.f2, x)
194+
y = apply_where(x % 2 == 0, x, self.f1, self.f2)
195+
assert get_device(y) == device
196+
y = apply_where(x % 2 == 0, x, self.f1, fill_value=0)
197+
assert get_device(y) == device
198+
y = apply_where(x % 2 == 0, x, self.f1, fill_value=x)
223199
assert get_device(y) == device
224200

225201
# skip instead of xfail in order not to waste time
@@ -273,10 +249,9 @@ def f2(*args: Array) -> Array:
273249
rng = np.random.default_rng(rng_seed)
274250
cond = xp.asarray(rng.random(size=cond_shape) > p)
275251

276-
# Use apply_where instead of apply_where_jit to speed the test up
277-
res1 = apply_where(cond, f1, arrays, fill_value=fill_value)
278-
res2 = apply_where(cond, f1, f2, arrays)
279-
res3 = apply_where(cond, f1, arrays, fill_value=float_fill_value)
252+
res1 = apply_where(cond, arrays, f1, fill_value=fill_value)
253+
res2 = apply_where(cond, arrays, f1, f2)
254+
res3 = apply_where(cond, arrays, f1, fill_value=float_fill_value)
280255

281256
ref1 = xp.where(cond, f1(*arrays), fill_value)
282257
ref2 = xp.where(cond, f1(*arrays), f2(*arrays))

0 commit comments

Comments
 (0)