From c89ba39e87bf2e488141cbc3c6690b7886ec4211 Mon Sep 17 00:00:00 2001 From: ago109 Date: Mon, 1 Jul 2024 14:10:53 -0400 Subject: [PATCH] edit to lif --- ngclearn/components/neurons/spiking/LIFCell.py | 4 ++-- ngclearn/utils/surrogate_fx.py | 1 - 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/ngclearn/components/neurons/spiking/LIFCell.py b/ngclearn/components/neurons/spiking/LIFCell.py index 0358b494..37c51df9 100644 --- a/ngclearn/components/neurons/spiking/LIFCell.py +++ b/ngclearn/components/neurons/spiking/LIFCell.py @@ -5,7 +5,7 @@ from ngclearn.components.jaxComponent import JaxComponent from ngclearn.utils.diffeq.ode_utils import get_integrator_code, \ step_euler, step_rk2 -from ngclearn.utils.surrogate_fx import straight_through_estimator +from ngclearn.utils.surrogate_fx import straight_through_estimator, triangular_estimator @jit def _update_times(t, s, tols): @@ -237,7 +237,7 @@ def __init__(self, name, n_units, tau_m, resist_m=1., thr=-52., v_rest=-65., self.n_units = n_units ## set up surrogate function for spike emission - self.spike_fx, self.d_spike_fx = straight_through_estimator() + self.spike_fx, self.d_spike_fx = triangular_estimator() # straight_through_estimator() ## Compartment setup restVals = jnp.zeros((self.batch_size, self.n_units)) diff --git a/ngclearn/utils/surrogate_fx.py b/ngclearn/utils/surrogate_fx.py index 84677eee..3371829f 100644 --- a/ngclearn/utils/surrogate_fx.py +++ b/ngclearn/utils/surrogate_fx.py @@ -53,7 +53,6 @@ def d_spike_fx(v, thr): mask = (v < thr).astype(jnp.float32) dfx = mask * thr - (1. - mask) * thr return dfx - return dfx if get_surr_fx == True: return spike_fx, spike_fx, d_spike_fx else: