Skip to content

Commit

Permalink
edit to lif
Browse files Browse the repository at this point in the history
  • Loading branch information
ago109 committed Jul 1, 2024
1 parent 71d9bf6 commit c89ba39
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 3 deletions.
4 changes: 2 additions & 2 deletions ngclearn/components/neurons/spiking/LIFCell.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from ngclearn.components.jaxComponent import JaxComponent
from ngclearn.utils.diffeq.ode_utils import get_integrator_code, \
step_euler, step_rk2
from ngclearn.utils.surrogate_fx import straight_through_estimator
from ngclearn.utils.surrogate_fx import straight_through_estimator, triangular_estimator

@jit
def _update_times(t, s, tols):
Expand Down Expand Up @@ -237,7 +237,7 @@ def __init__(self, name, n_units, tau_m, resist_m=1., thr=-52., v_rest=-65.,
self.n_units = n_units

## set up surrogate function for spike emission
self.spike_fx, self.d_spike_fx = straight_through_estimator()
self.spike_fx, self.d_spike_fx = triangular_estimator() # straight_through_estimator()

## Compartment setup
restVals = jnp.zeros((self.batch_size, self.n_units))
Expand Down
1 change: 0 additions & 1 deletion ngclearn/utils/surrogate_fx.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ def d_spike_fx(v, thr):
mask = (v < thr).astype(jnp.float32)
dfx = mask * thr - (1. - mask) * thr
return dfx
return dfx
if get_surr_fx == True:
return spike_fx, spike_fx, d_spike_fx
else:
Expand Down

0 comments on commit c89ba39

Please sign in to comment.