diff --git a/ngclearn/components/neurons/spiking/LIFCell.py b/ngclearn/components/neurons/spiking/LIFCell.py index 9f1e69b4b..55986aed2 100644 --- a/ngclearn/components/neurons/spiking/LIFCell.py +++ b/ngclearn/components/neurons/spiking/LIFCell.py @@ -5,6 +5,7 @@ from ngclearn.components.jaxComponent import JaxComponent from ngclearn.utils.diffeq.ode_utils import get_integrator_code, \ step_euler, step_rk2 +from ngclearn.utils.surrogate_fx import straight_through_estimator @jit def _update_times(t, s, tols): @@ -235,6 +236,9 @@ def __init__(self, name, n_units, tau_m, resist_m=1., thr=-52., v_rest=-65., self.batch_size = 1 self.n_units = n_units + ## set up surrogate function for spike emission + self.spike_fx, self.d_spike_fx = straight_through_estimator() + ## Compartment setup restVals = jnp.zeros((self.batch_size, self.n_units)) thr0 = 0. @@ -249,10 +253,11 @@ def __init__(self, name, n_units, tau_m, resist_m=1., thr=-52., v_rest=-65., self.rfr = Compartment(restVals + self.refract_T) self.thr_theta = Compartment(restVals + thr0) self.tols = Compartment(restVals) ## time-of-last-spike + self.surrogate = Compartment(restVals + 1.) ## surrogate signal @staticmethod def _advance_state(t, dt, tau_m, R_m, v_rest, v_reset, v_decay, refract_T, - thr, tau_theta, theta_plus, one_spike, intgFlag, + thr, tau_theta, theta_plus, one_spike, intgFlag, d_spike_fx, key, j, v, s, rfr, thr_theta, tols): skey = None ## this is an empty dkey if single_spike mode turned off if one_spike: ## old code ~> if self.one_spike is False: @@ -260,6 +265,7 @@ def _advance_state(t, dt, tau_m, R_m, v_rest, v_reset, v_decay, refract_T, ## run one integration step for neuronal dynamics #j = _modify_current(j, dt, tau_m, R_m) ## re-scale current in prep for volt ODE j = j * R_m + surrogate = d_spike_fx(j) v, s, raw_spikes, rfr = _run_cell(dt, j, v, thr, thr_theta, rfr, skey, tau_m, v_rest, v_reset, v_decay, refract_T, intgFlag) @@ -268,10 +274,10 @@ def _advance_state(t, dt, tau_m, R_m, v_rest, v_reset, v_decay, refract_T, thr_theta = _update_theta(dt, thr_theta, raw_spikes, tau_theta, theta_plus) ## update tols tols = _update_times(t, s, tols) - return v, s, raw_spikes, rfr, thr_theta, tols, key + return v, s, raw_spikes, rfr, thr_theta, tols, key, surrogate @resolver(_advance_state) - def advance_state(self, v, s, s_raw, rfr, thr_theta, tols, key): + def advance_state(self, v, s, s_raw, rfr, thr_theta, tols, key, surrogate): self.v.set(v) self.s.set(s) self.s_raw.set(s_raw) @@ -279,6 +285,7 @@ def advance_state(self, v, s, s_raw, rfr, thr_theta, tols, key): self.thr_theta.set(thr_theta) self.tols.set(tols) self.key.set(key) + self.surrogate.set(surrogate) @staticmethod def _reset(batch_size, n_units, v_rest, refract_T): @@ -290,10 +297,11 @@ def _reset(batch_size, n_units, v_rest, refract_T): rfr = restVals + refract_T #thr_theta = restVals ## do not reset thr_theta tols = restVals #+ 0 - return j, v, s, s_raw, rfr, tols + surrogate = restVals + 1. + return j, v, s, s_raw, rfr, tols, surrogate @resolver(_reset) - def reset(self, j, v, s, s_raw, rfr, tols): + def reset(self, j, v, s, s_raw, rfr, tols, surrogate): self.j.set(j) self.v.set(v) self.s.set(s) @@ -301,6 +309,7 @@ def reset(self, j, v, s, s_raw, rfr, tols): self.rfr.set(rfr) #self.thr_theta.set(thr_theta) self.tols.set(tols) + self.surrogate.set(surrogate) def save(self, directory, **kwargs): file_name = directory + "/" + self.name + ".npz"