From a08ad4848c802a7c19e7c97635c4364eeeb9f893 Mon Sep 17 00:00:00 2001 From: Sichao He <1310722434@qq.com> Date: Mon, 16 Dec 2024 13:25:56 +0800 Subject: [PATCH] [ci] Fix bug in `test_ndarray.py` when using the latest version of JAX (#708) * Update test_ndarray.py * Update ndarray.py --- brainpy/_src/math/ndarray.py | 2 +- brainpy/_src/math/tests/test_ndarray.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/brainpy/_src/math/ndarray.py b/brainpy/_src/math/ndarray.py index b435415d6..47b81d18e 100644 --- a/brainpy/_src/math/ndarray.py +++ b/brainpy/_src/math/ndarray.py @@ -98,7 +98,7 @@ def _check_tracer(self): self_value = self.value if hasattr(self_value, '_trace') and hasattr(self_value._trace.main, 'jaxpr_stack'): if len(self_value._trace.main.jaxpr_stack) == 0: - raise RuntimeError('This Array is modified during the transformation. ' + raise jax.errors.UnexpectedTracerError('This Array is modified during the transformation. ' 'BrainPy only supports transformations for Variable. ' 'Please declare it as a Variable.') from jax.core.escaped_tracer_error(self_value, None) return self_value diff --git a/brainpy/_src/math/tests/test_ndarray.py b/brainpy/_src/math/tests/test_ndarray.py index a09129129..e9acff357 100644 --- a/brainpy/_src/math/tests/test_ndarray.py +++ b/brainpy/_src/math/tests/test_ndarray.py @@ -62,7 +62,7 @@ def _f(self, b): def test_tracing(self): print(self.f(1.)) - with self.assertRaises(RuntimeError): + with self.assertRaises(jax.errors.UnexpectedTracerError): print(self.f(bm.ones(10)))