Skip to content

Commit 6458d8d

Browse files
committed
Ignore JAX bug
1 parent ad8c777 commit 6458d8d

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

Diff for: tests/test_funcs.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -258,9 +258,12 @@ def f2(*args: Array) -> Array:
258258
# TODO remove asarrays once all backends support Array API 2024.12
259259
ref3 = xp.where(cond, *asarrays(f1(*arrays), float_fill_value, xp=xp))
260260

261-
xp_assert_close(res1, ref1, rtol=2e-16)
262-
xp_assert_equal(res2, ref2)
263-
xp_assert_equal(res3, ref3)
261+
# https://github.com/jax-ml/jax/issues/26658
262+
atol = 1e-300 if library is Backend.JAX else 0
263+
264+
xp_assert_close(res1, ref1, atol=atol, rtol=2e-16)
265+
xp_assert_close(res2, ref2, atol=atol, rtol=2e-16)
266+
xp_assert_close(res3, ref3, atol=atol, rtol=2e-16)
264267

265268

266269
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no expand_dims")

0 commit comments

Comments
 (0)