
Commit a51e3a7

Fix autograd (#687)
* fix autograd bug
* fix autograd bug
* fix autograd bug
* fix autograd bug
1 parent d1a4afb commit a51e3a7

2 files changed: +21 additions, -13 deletions


brainpy/_src/math/object_transform/autograd.py

Lines changed: 10 additions & 13 deletions
@@ -94,7 +94,6 @@ def __init__(
     self.target = target
 
     # transform
-    self._eval_dyn_vars = False
     self._grad_transform = transform
     self._dyn_vars = VariableStack()
     self._transform = None
@@ -198,20 +197,18 @@ def __call__(self, *args, **kwargs):
      )
      return self._return(rets)
 
-    elif not self._eval_dyn_vars:  # evaluate dynamical variables
-      stack = get_stack_cache(self.target)
-      if stack is None:
-        with VariableStack() as stack:
-          rets = eval_shape(self._transform,
-                            [v.value for v in self._grad_vars],  # variables for gradients
-                            {},  # dynamical variables
-                            *args,
-                            **kwargs)
-        cache_stack(self.target, stack)
-
+    # evaluate dynamical variables
+    stack = get_stack_cache(self.target)
+    if stack is None:
+      with VariableStack() as stack:
+        rets = eval_shape(self._transform,
+                          [v.value for v in self._grad_vars],  # variables for gradients
+                          {},  # dynamical variables
+                          *args,
+                          **kwargs)
+      cache_stack(self.target, stack)
     self._dyn_vars = stack
     self._dyn_vars.remove_by_id(*[id(v) for v in self._grad_vars])
-    self._eval_dyn_vars = True
 
     # if not the outermost transformation
     if not stack.is_first_stack():
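As the diff shows, the one-shot _eval_dyn_vars flag is removed, so the stack of dynamical variables is now resolved through get_stack_cache / cache_stack on every call rather than only the first time. Below is a minimal sketch of the user-facing pattern this targets; it mirrors the new test added in this commit (brainpy.math imported as bm, as in the test file), and is illustrative only:

  import brainpy.math as bm

  def call(a, b, c):
    return bm.sum(a + b + c)

  # Wrapping the gradient transform in jit is the case exercised by the new test:
  # the VariableStack must be evaluated correctly inside the jitted call.
  f_grad = bm.jit(bm.grad(call))

  bm.random.seed(1)
  a = bm.ones(10)
  b = bm.random.randn(10)
  c = bm.random.uniform(size=10)

  # d(sum(a + b + c))/da == 1 elementwise
  assert (f_grad(a, b, c) == 1.).all()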

brainpy/_src/math/object_transform/tests/test_autograd.py

Lines changed: 11 additions & 0 deletions
@@ -86,6 +86,17 @@ def call(a, b, c):
     assert aux[1] == bm.exp(0.1)
 
 
+  def test_grad_jit(self):
+    def call(a, b, c): return bm.sum(a + b + c)
+
+    bm.random.seed(1)
+    a = bm.ones(10)
+    b = bm.random.randn(10)
+    c = bm.random.uniform(size=10)
+    f_grad = bm.jit(bm.grad(call))
+    assert (f_grad(a, b, c) == 1.).all()
+
+
 class TestObjectFuncGrad(unittest.TestCase):
   def test_grad_ob1(self):
     class Test(bp.BrainPyObject):
