Skip to content

Commit a2d2a6d

Browse files
committed
added surrogate compartment to LIFCell
1 parent e678750 commit a2d2a6d

File tree

1 file changed

+14
-5
lines changed

1 file changed

+14
-5
lines changed

ngclearn/components/neurons/spiking/LIFCell.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from ngclearn.components.jaxComponent import JaxComponent
66
from ngclearn.utils.diffeq.ode_utils import get_integrator_code, \
77
step_euler, step_rk2
8+
from ngclearn.utils.surrogate_fx import straight_through_estimator
89

910
@jit
1011
def _update_times(t, s, tols):
@@ -235,6 +236,9 @@ def __init__(self, name, n_units, tau_m, resist_m=1., thr=-52., v_rest=-65.,
235236
self.batch_size = 1
236237
self.n_units = n_units
237238

239+
## set up surrogate function for spike emission
240+
self.spike_fx, self.d_spike_fx = straight_through_estimator()
241+
238242
## Compartment setup
239243
restVals = jnp.zeros((self.batch_size, self.n_units))
240244
thr0 = 0.
@@ -249,17 +253,19 @@ def __init__(self, name, n_units, tau_m, resist_m=1., thr=-52., v_rest=-65.,
249253
self.rfr = Compartment(restVals + self.refract_T)
250254
self.thr_theta = Compartment(restVals + thr0)
251255
self.tols = Compartment(restVals) ## time-of-last-spike
256+
self.surrogate = Compartment(restVals + 1.) ## surrogate signal
252257

253258
@staticmethod
254259
def _advance_state(t, dt, tau_m, R_m, v_rest, v_reset, v_decay, refract_T,
255-
thr, tau_theta, theta_plus, one_spike, intgFlag,
260+
thr, tau_theta, theta_plus, one_spike, intgFlag, d_spike_fx,
256261
key, j, v, s, rfr, thr_theta, tols):
257262
skey = None ## this is an empty dkey if single_spike mode turned off
258263
if one_spike: ## old code ~> if self.one_spike is False:
259264
key, skey = random.split(key, 2)
260265
## run one integration step for neuronal dynamics
261266
#j = _modify_current(j, dt, tau_m, R_m) ## re-scale current in prep for volt ODE
262267
j = j * R_m
268+
surrogate = d_spike_fx(j)
263269
v, s, raw_spikes, rfr = _run_cell(dt, j, v, thr, thr_theta, rfr, skey,
264270
tau_m, v_rest, v_reset, v_decay, refract_T,
265271
intgFlag)
@@ -268,17 +274,18 @@ def _advance_state(t, dt, tau_m, R_m, v_rest, v_reset, v_decay, refract_T,
268274
thr_theta = _update_theta(dt, thr_theta, raw_spikes, tau_theta, theta_plus)
269275
## update tols
270276
tols = _update_times(t, s, tols)
271-
return v, s, raw_spikes, rfr, thr_theta, tols, key
277+
return v, s, raw_spikes, rfr, thr_theta, tols, key, surrogate
272278

273279
@resolver(_advance_state)
274-
def advance_state(self, v, s, s_raw, rfr, thr_theta, tols, key):
280+
def advance_state(self, v, s, s_raw, rfr, thr_theta, tols, key, surrogate):
275281
self.v.set(v)
276282
self.s.set(s)
277283
self.s_raw.set(s_raw)
278284
self.rfr.set(rfr)
279285
self.thr_theta.set(thr_theta)
280286
self.tols.set(tols)
281287
self.key.set(key)
288+
self.surrogate.set(surrogate)
282289

283290
@staticmethod
284291
def _reset(batch_size, n_units, v_rest, refract_T):
@@ -290,17 +297,19 @@ def _reset(batch_size, n_units, v_rest, refract_T):
290297
rfr = restVals + refract_T
291298
#thr_theta = restVals ## do not reset thr_theta
292299
tols = restVals #+ 0
293-
return j, v, s, s_raw, rfr, tols
300+
surrogate = restVals + 1.
301+
return j, v, s, s_raw, rfr, tols, surrogate
294302

295303
@resolver(_reset)
296-
def reset(self, j, v, s, s_raw, rfr, tols):
304+
def reset(self, j, v, s, s_raw, rfr, tols, surrogate):
297305
self.j.set(j)
298306
self.v.set(v)
299307
self.s.set(s)
300308
self.s_raw.set(s_raw)
301309
self.rfr.set(rfr)
302310
#self.thr_theta.set(thr_theta)
303311
self.tols.set(tols)
312+
self.surrogate.set(surrogate)
304313

305314
def save(self, directory, **kwargs):
306315
file_name = directory + "/" + self.name + ".npz"

0 commit comments

Comments
 (0)