Skip to content

Commit a45dbbd

Browse files
committed
edit to lif
1 parent c89ba39 commit a45dbbd

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

ngclearn/components/neurons/spiking/LIFCell.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from ngclearn.components.jaxComponent import JaxComponent
66
from ngclearn.utils.diffeq.ode_utils import get_integrator_code, \
77
step_euler, step_rk2
8-
from ngclearn.utils.surrogate_fx import straight_through_estimator, triangular_estimator
8+
from ngclearn.utils.surrogate_fx import arctan_estimator, triangular_estimator
99

1010
@jit
1111
def _update_times(t, s, tols):
@@ -237,6 +237,7 @@ def __init__(self, name, n_units, tau_m, resist_m=1., thr=-52., v_rest=-65.,
237237
self.n_units = n_units
238238

239239
## set up surrogate function for spike emission
240+
#self.spike_fx, self.d_spike_fx = arctan_estimator() #
240241
self.spike_fx, self.d_spike_fx = triangular_estimator() # straight_through_estimator()
241242

242243
## Compartment setup
@@ -264,10 +265,11 @@ def _advance_state(t, dt, tau_m, R_m, v_rest, v_reset, v_decay, refract_T,
264265
key, skey = random.split(key, 2)
265266
## run one integration step for neuronal dynamics
266267
j = j * R_m
267-
surrogate = d_spike_fx(j, thr + thr_theta)
268+
#surrogate = d_spike_fx(v, thr + thr_theta)
268269
v, s, raw_spikes, rfr = _run_cell(dt, j, v, thr, thr_theta, rfr, skey,
269270
tau_m, v_rest, v_reset, v_decay, refract_T,
270271
intgFlag)
272+
surrogate = d_spike_fx(v, thr + thr_theta)
271273
if tau_theta > 0.:
272274
## run one integration step for threshold dynamics
273275
thr_theta = _update_theta(dt, thr_theta, raw_spikes, tau_theta, theta_plus)

0 commit comments

Comments
 (0)