@@ -151,7 +151,8 @@ class LIFCell(JaxComponent): ## leaky integrate-and-fire cell
151
151
@deprecate_args (thr_jitter = None )
152
152
def __init__ (self , name , n_units , tau_m , resist_m = 1. , thr = - 52. , v_rest = - 65. ,
153
153
v_reset = - 60. , v_decay = 1. , tau_theta = 1e7 , theta_plus = 0.05 ,
154
- refract_time = 5. , one_spike = False , integration_type = "euler" , ** kwargs ):
154
+ refract_time = 5. , one_spike = False , integration_type = "euler" ,
155
+ surrgoate_type = "straight_through" , ** kwargs ):
155
156
super ().__init__ (name , ** kwargs )
156
157
157
158
## Integration properties
@@ -179,14 +180,13 @@ def __init__(self, name, n_units, tau_m, resist_m=1., thr=-52., v_rest=-65.,
179
180
self .n_units = n_units
180
181
181
182
## set up surrogate function for spike emission
182
- surrgoate_type = "secant_lif"
183
183
if surrgoate_type == "secant_lif" :
184
184
self .spike_fx , self .d_spike_fx = secant_lif_estimator ()
185
185
elif surrgoate_type == "arctan" :
186
186
self .spike_fx , self .d_spike_fx = arctan_estimator ()
187
187
elif surrgoate_type == "triangular" :
188
188
self .spike_fx , self .d_spike_fx = triangular_estimator ()
189
- else : ## default is the straight-through estimator (STE)
189
+ else : ## default: straight_through
190
190
self .spike_fx , self .d_spike_fx = straight_through_estimator ()
191
191
192
192
0 commit comments