@@ -229,14 +229,17 @@ def test_hypothesis( # type: ignore[no-any-decorated]
229
229
cond_shape , * shapes = input_shapes
230
230
231
231
# cupy/cupy#8382
232
- elements = {"allow_subnormal" : False } if library is Backend .CUPY else None
232
+ # https://github.com/jax-ml/jax/issues/26658
233
+ elements = {"allow_subnormal" : library not in (Backend .CUPY , Backend .JAX )}
233
234
234
235
fill_value = xp .asarray (
235
236
data .draw (npst .arrays (dtype = dtype , shape = (), elements = elements ))
236
237
)
237
238
float_fill_value = float (fill_value )
238
239
arrays = tuple (
239
- xp .asarray (data .draw (npst .arrays (dtype = dtype , shape = shape )))
240
+ xp .asarray (
241
+ data .draw (npst .arrays (dtype = dtype , shape = shape , elements = elements ))
242
+ )
240
243
for shape in shapes
241
244
)
242
245
@@ -258,12 +261,9 @@ def f2(*args: Array) -> Array:
258
261
# TODO remove asarrays once all backends support Array API 2024.12
259
262
ref3 = xp .where (cond , * asarrays (f1 (* arrays ), float_fill_value , xp = xp ))
260
263
261
- # https://github.com/jax-ml/jax/issues/26658
262
- atol = 1e-300 if library is Backend .JAX else 0
263
-
264
- xp_assert_close (res1 , ref1 , atol = atol , rtol = 2e-16 )
265
- xp_assert_close (res2 , ref2 , atol = atol , rtol = 2e-16 )
266
- xp_assert_close (res3 , ref3 , atol = atol , rtol = 2e-16 )
264
+ xp_assert_close (res1 , ref1 , rtol = 2e-16 )
265
+ xp_assert_equal (res2 , ref2 )
266
+ xp_assert_equal (res3 , ref3 )
267
267
268
268
269
269
@pytest .mark .xfail_xp_backend (Backend .SPARSE , reason = "no expand_dims" )
0 commit comments