Skip to content

Commit e307525

Browse files
committed
Tweak unit test
1 parent 6458d8d commit e307525

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

tests/test_funcs.py

+8-8
Original file line numberDiff line numberDiff line change
@@ -229,14 +229,17 @@ def test_hypothesis( # type: ignore[no-any-decorated]
229229
cond_shape, *shapes = input_shapes
230230

231231
# cupy/cupy#8382
232-
elements = {"allow_subnormal": False} if library is Backend.CUPY else None
232+
# https://github.com/jax-ml/jax/issues/26658
233+
elements = {"allow_subnormal": library not in (Backend.CUPY, Backend.JAX)}
233234

234235
fill_value = xp.asarray(
235236
data.draw(npst.arrays(dtype=dtype, shape=(), elements=elements))
236237
)
237238
float_fill_value = float(fill_value)
238239
arrays = tuple(
239-
xp.asarray(data.draw(npst.arrays(dtype=dtype, shape=shape)))
240+
xp.asarray(
241+
data.draw(npst.arrays(dtype=dtype, shape=shape, elements=elements))
242+
)
240243
for shape in shapes
241244
)
242245

@@ -258,12 +261,9 @@ def f2(*args: Array) -> Array:
258261
# TODO remove asarrays once all backends support Array API 2024.12
259262
ref3 = xp.where(cond, *asarrays(f1(*arrays), float_fill_value, xp=xp))
260263

261-
# https://github.com/jax-ml/jax/issues/26658
262-
atol = 1e-300 if library is Backend.JAX else 0
263-
264-
xp_assert_close(res1, ref1, atol=atol, rtol=2e-16)
265-
xp_assert_close(res2, ref2, atol=atol, rtol=2e-16)
266-
xp_assert_close(res3, ref3, atol=atol, rtol=2e-16)
264+
xp_assert_close(res1, ref1, rtol=2e-16)
265+
xp_assert_equal(res2, ref2)
266+
xp_assert_equal(res3, ref3)
267267

268268

269269
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no expand_dims")

0 commit comments

Comments
 (0)