Skip to content

Commit a08ad48

Browse files
authored
[ci] Fix bug in test_ndarray.py when using the latest version of JAX (#708)
* Update test_ndarray.py * Update ndarray.py
1 parent c289edd commit a08ad48

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

brainpy/_src/math/ndarray.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def _check_tracer(self):
9898
self_value = self.value
9999
if hasattr(self_value, '_trace') and hasattr(self_value._trace.main, 'jaxpr_stack'):
100100
if len(self_value._trace.main.jaxpr_stack) == 0:
101-
raise RuntimeError('This Array is modified during the transformation. '
101+
raise jax.errors.UnexpectedTracerError('This Array is modified during the transformation. '
102102
'BrainPy only supports transformations for Variable. '
103103
'Please declare it as a Variable.') from jax.core.escaped_tracer_error(self_value, None)
104104
return self_value

brainpy/_src/math/tests/test_ndarray.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def _f(self, b):
6262

6363
def test_tracing(self):
6464
print(self.f(1.))
65-
with self.assertRaises(RuntimeError):
65+
with self.assertRaises(jax.errors.UnexpectedTracerError):
6666
print(self.f(bm.ones(10)))
6767

6868

0 commit comments

Comments
 (0)