1
1
import contextlib
2
2
import math
3
3
import warnings
4
- from collections .abc import Callable
5
4
from types import ModuleType
6
5
7
6
import hypothesis
36
35
# some xp backends are untyped
37
36
# mypy: disable-error-code=no-untyped-def
38
37
38
+ lazy_xp_function (apply_where , static_argnums = (2 , 3 ), static_argnames = "xp" )
39
39
lazy_xp_function (atleast_nd , static_argnames = ("ndim" , "xp" ))
40
40
lazy_xp_function (cov , static_argnames = "xp" )
41
41
# FIXME .device attribute https://github.com/data-apis/array-api-compat/pull/238
50
50
lazy_xp_function (sinc , jax_jit = False , static_argnames = "xp" )
51
51
52
52
53
- def apply_where_jit ( # type: ignore[no-any-explicit]
54
- cond : Array ,
55
- f1 : Callable [..., Array ],
56
- f2 : Callable [..., Array ] | None ,
57
- args : Array | tuple [Array , ...],
58
- fill_value : Array | int | float | complex | bool | None = None ,
59
- xp : ModuleType | None = None ,
60
- ) -> Array :
61
- """
62
- Work around jax.jit's inability to handle variadic positional arguments.
63
-
64
- This is a lazy_xp_function artefact for when jax.jit is applied directly
65
- to apply_where, which would not happen in real life.
66
- """
67
- if f2 is None :
68
- return apply_where (cond , f1 , args , fill_value = fill_value , xp = xp )
69
- assert fill_value is None
70
- return apply_where (cond , f1 , f2 , args , xp = xp )
71
-
72
-
73
- lazy_xp_function (apply_where_jit , static_argnames = ("f1" , "f2" , "xp" ))
74
-
75
-
76
53
class TestApplyWhere :
77
54
@staticmethod
78
55
def f1 (x : Array , y : Array | int = 10 ) -> Array :
@@ -86,27 +63,27 @@ def f2(x: Array, y: Array | int = 10) -> Array:
86
63
def test_f1_f2 (self , xp : ModuleType ):
87
64
x = xp .asarray ([1 , 2 , 3 , 4 ])
88
65
cond = x % 2 == 0
89
- actual = apply_where_jit (cond , self .f1 , self .f2 , x )
66
+ actual = apply_where (cond , x , self .f1 , self .f2 )
90
67
expect = xp .where (cond , self .f1 (x ), self .f2 (x ))
91
68
xp_assert_equal (actual , expect )
92
69
93
70
@pytest .mark .xfail_xp_backend (Backend .SPARSE , reason = "read-only without .at" )
94
71
def test_fill_value (self , xp : ModuleType ):
95
72
x = xp .asarray ([1 , 2 , 3 , 4 ])
96
73
cond = x % 2 == 0
97
- actual = apply_where_jit (x % 2 == 0 , self .f1 , None , x , fill_value = 0 )
74
+ actual = apply_where (x % 2 == 0 , x , self .f1 , fill_value = 0 )
98
75
expect = xp .where (cond , self .f1 (x ), xp .asarray (0 ))
99
76
xp_assert_equal (actual , expect )
100
77
101
- actual = apply_where_jit (x % 2 == 0 , self .f1 , None , x , fill_value = xp .asarray (0 ))
78
+ actual = apply_where (x % 2 == 0 , x , self .f1 , fill_value = xp .asarray (0 ))
102
79
xp_assert_equal (actual , expect )
103
80
104
81
@pytest .mark .xfail_xp_backend (Backend .SPARSE , reason = "read-only without .at" )
105
82
def test_args_tuple (self , xp : ModuleType ):
106
83
x = xp .asarray ([1 , 2 , 3 , 4 ])
107
84
y = xp .asarray ([10 , 20 , 30 , 40 ])
108
85
cond = x % 2 == 0
109
- actual = apply_where_jit (cond , self .f1 , self .f2 , ( x , y ) )
86
+ actual = apply_where (cond , ( x , y ), self .f1 , self .f2 )
110
87
expect = xp .where (cond , self .f1 (x , y ), self .f2 (x , y ))
111
88
xp_assert_equal (actual , expect )
112
89
@@ -116,21 +93,21 @@ def test_broadcast(self, xp: ModuleType):
116
93
y = xp .asarray ([[10 ], [20 ], [30 ]])
117
94
cond = xp .broadcast_to (xp .asarray (True ), (4 , 1 , 1 ))
118
95
119
- actual = apply_where_jit (cond , self .f1 , self .f2 , ( x , y ) )
96
+ actual = apply_where (cond , ( x , y ), self .f1 , self .f2 )
120
97
expect = xp .where (cond , self .f1 (x , y ), self .f2 (x , y ))
121
98
xp_assert_equal (actual , expect )
122
99
123
- actual = apply_where_jit (
100
+ actual = apply_where (
124
101
cond ,
102
+ (x , y ),
125
103
lambda x , _ : x , # pyright: ignore[reportUnknownArgumentType]
126
104
lambda _ , y : y , # pyright: ignore[reportUnknownArgumentType]
127
- (x , y ),
128
105
)
129
106
expect = xp .where (cond , x , y )
130
107
xp_assert_equal (actual , expect )
131
108
132
109
# Shaped fill_value
133
- actual = apply_where_jit (cond , self .f1 , None , x , fill_value = y )
110
+ actual = apply_where (cond , x , self .f1 , fill_value = y )
134
111
expect = xp .where (cond , self .f1 (x ), y )
135
112
xp_assert_equal (actual , expect )
136
113
@@ -141,15 +118,15 @@ def test_dtype_propagation(self, xp: ModuleType, library: Backend):
141
118
cond = x % 2 == 0
142
119
143
120
mxp = np if library is Backend .DASK else xp
144
- actual = apply_where_jit (
121
+ actual = apply_where (
145
122
cond ,
123
+ (x , y ),
146
124
self .f1 ,
147
125
lambda x , y : mxp .astype (x - y , xp .int64 ), # pyright: ignore[reportUnknownArgumentType]
148
- (x , y ),
149
126
)
150
127
assert actual .dtype == xp .int64
151
128
152
- actual = apply_where_jit (cond , self .f1 , None , y , fill_value = 5 )
129
+ actual = apply_where (cond , y , self .f1 , fill_value = 5 )
153
130
assert actual .dtype == xp .int16
154
131
155
132
@pytest .mark .xfail_xp_backend (Backend .SPARSE , reason = "read-only without .at" )
@@ -168,14 +145,14 @@ def test_dtype_propagation_fill_value(
168
145
cond = x % 2 == 0
169
146
fill_value = xp .asarray (fill_value_raw , dtype = getattr (xp , fill_value_dtype ))
170
147
171
- actual = apply_where_jit (cond , self .f1 , None , x , fill_value = fill_value )
148
+ actual = apply_where (cond , x , self .f1 , fill_value = fill_value )
172
149
assert actual .dtype == getattr (xp , expect_dtype )
173
150
174
151
@pytest .mark .xfail_xp_backend (Backend .SPARSE , reason = "read-only without .at" )
175
152
def test_dont_overwrite_fill_value (self , xp : ModuleType ):
176
153
x = xp .asarray ([1 , 2 ])
177
154
fill_value = xp .asarray ([100 , 200 ])
178
- actual = apply_where_jit (x % 2 == 0 , self .f1 , None , x , fill_value = fill_value )
155
+ actual = apply_where (x % 2 == 0 , x , self .f1 , fill_value = fill_value )
179
156
xp_assert_equal (actual , xp .asarray ([100 , 12 ]))
180
157
xp_assert_equal (fill_value , xp .asarray ([100 , 200 ]))
181
158
@@ -184,11 +161,11 @@ def test_dont_run_on_false(self, xp: ModuleType):
184
161
x = xp .asarray ([1.0 , 2.0 , 0.0 ])
185
162
y = xp .asarray ([0.0 , 3.0 , 4.0 ])
186
163
# On NumPy, division by zero will trigger warnings
187
- actual = apply_where_jit (
164
+ actual = apply_where (
188
165
x == 0 ,
166
+ (x , y ),
189
167
lambda x , y : x / y , # pyright: ignore[reportUnknownArgumentType]
190
168
lambda x , y : y / x , # pyright: ignore[reportUnknownArgumentType]
191
- (x , y ),
192
169
)
193
170
xp_assert_equal (actual , xp .asarray ([0.0 , 1.5 , 0.0 ]))
194
171
@@ -197,29 +174,28 @@ def test_bad_args(self, xp: ModuleType):
197
174
cond = x % 2 == 0
198
175
# Neither f2 nor fill_value
199
176
with pytest .raises (TypeError , match = "Exactly one of" ):
200
- apply_where (cond , self .f1 , x ) # type: ignore[call-overload] # pyright: ignore[reportCallIssue]
177
+ apply_where (cond , x , self .f1 ) # type: ignore[call-overload] # pyright: ignore[reportCallIssue]
201
178
# Both f2 and fill_value
202
179
with pytest .raises (TypeError , match = "Exactly one of" ):
203
- apply_where (cond , self .f1 , self .f2 , x , fill_value = 0 ) # type: ignore[call-overload] # pyright: ignore[reportCallIssue]
204
- # Multiple args; forgot to wrap them in a tuple
205
- with pytest .raises (TypeError , match = "takes from 3 to 4 positional arguments" ):
206
- apply_where (cond , self .f1 , self .f2 , x , x ) # type: ignore[call-overload] # pyright: ignore[reportCallIssue]
207
- with pytest .raises (TypeError , match = "callable" ):
208
- apply_where (cond , self .f1 , x , x , fill_value = 0 ) # type: ignore[call-overload] # pyright: ignore[reportCallIssue]
180
+ apply_where (cond , x , self .f1 , self .f2 , fill_value = 0 ) # type: ignore[call-overload] # pyright: ignore[reportCallIssue]
209
181
210
182
@pytest .mark .skip_xp_backend (Backend .NUMPY_READONLY , reason = "xp=xp" )
211
183
@pytest .mark .xfail_xp_backend (Backend .SPARSE , reason = "read-only without .at" )
212
184
def test_xp (self , xp : ModuleType ):
213
185
x = xp .asarray ([1 , 2 , 3 , 4 ])
214
186
cond = x % 2 == 0
215
- actual = apply_where_jit (cond , self .f1 , self .f2 , x , xp = xp )
187
+ actual = apply_where (cond , x , self .f1 , self .f2 , xp = xp )
216
188
expect = xp .where (cond , self .f1 (x ), self .f2 (x ))
217
189
xp_assert_equal (actual , expect )
218
190
219
191
@pytest .mark .xfail_xp_backend (Backend .SPARSE , reason = "read-only without .at" )
220
192
def test_device (self , xp : ModuleType , device : Device ):
221
193
x = xp .asarray ([1 , 2 , 3 , 4 ], device = device )
222
- y = apply_where_jit (x % 2 == 0 , self .f1 , self .f2 , x )
194
+ y = apply_where (x % 2 == 0 , x , self .f1 , self .f2 )
195
+ assert get_device (y ) == device
196
+ y = apply_where (x % 2 == 0 , x , self .f1 , fill_value = 0 )
197
+ assert get_device (y ) == device
198
+ y = apply_where (x % 2 == 0 , x , self .f1 , fill_value = x )
223
199
assert get_device (y ) == device
224
200
225
201
# skip instead of xfail in order not to waste time
@@ -273,10 +249,9 @@ def f2(*args: Array) -> Array:
273
249
rng = np .random .default_rng (rng_seed )
274
250
cond = xp .asarray (rng .random (size = cond_shape ) > p )
275
251
276
- # Use apply_where instead of apply_where_jit to speed the test up
277
- res1 = apply_where (cond , f1 , arrays , fill_value = fill_value )
278
- res2 = apply_where (cond , f1 , f2 , arrays )
279
- res3 = apply_where (cond , f1 , arrays , fill_value = float_fill_value )
252
+ res1 = apply_where (cond , arrays , f1 , fill_value = fill_value )
253
+ res2 = apply_where (cond , arrays , f1 , f2 )
254
+ res3 = apply_where (cond , arrays , f1 , fill_value = float_fill_value )
280
255
281
256
ref1 = xp .where (cond , f1 (* arrays ), fill_value )
282
257
ref2 = xp .where (cond , f1 (* arrays ), f2 (* arrays ))
0 commit comments