diff --git a/ngclearn/components/neurons/spiking/LIFCell.py b/ngclearn/components/neurons/spiking/LIFCell.py index 0358b494..37c51df9 100644 --- a/ngclearn/components/neurons/spiking/LIFCell.py +++ b/ngclearn/components/neurons/spiking/LIFCell.py @@ -5,7 +5,7 @@ from ngclearn.components.jaxComponent import JaxComponent from ngclearn.utils.diffeq.ode_utils import get_integrator_code, \ step_euler, step_rk2 -from ngclearn.utils.surrogate_fx import straight_through_estimator +from ngclearn.utils.surrogate_fx import straight_through_estimator, triangular_estimator @jit def _update_times(t, s, tols): @@ -237,7 +237,7 @@ def __init__(self, name, n_units, tau_m, resist_m=1., thr=-52., v_rest=-65., self.n_units = n_units ## set up surrogate function for spike emission - self.spike_fx, self.d_spike_fx = straight_through_estimator() + self.spike_fx, self.d_spike_fx = triangular_estimator() # straight_through_estimator() ## Compartment setup restVals = jnp.zeros((self.batch_size, self.n_units)) diff --git a/ngclearn/utils/surrogate_fx.py b/ngclearn/utils/surrogate_fx.py index 84677eee..3371829f 100644 --- a/ngclearn/utils/surrogate_fx.py +++ b/ngclearn/utils/surrogate_fx.py @@ -53,7 +53,6 @@ def d_spike_fx(v, thr): mask = (v < thr).astype(jnp.float32) dfx = mask * thr - (1. - mask) * thr return dfx - return dfx if get_surr_fx == True: return spike_fx, spike_fx, d_spike_fx else: