|
5 | 5 | from ngclearn.components.jaxComponent import JaxComponent
|
6 | 6 | from ngclearn.utils.diffeq.ode_utils import get_integrator_code, \
|
7 | 7 | step_euler, step_rk2
|
8 |
| -from ngclearn.utils.surrogate_fx import straight_through_estimator, triangular_estimator |
| 8 | +from ngclearn.utils.surrogate_fx import arctan_estimator, triangular_estimator |
9 | 9 |
|
10 | 10 | @jit
|
11 | 11 | def _update_times(t, s, tols):
|
@@ -237,6 +237,7 @@ def __init__(self, name, n_units, tau_m, resist_m=1., thr=-52., v_rest=-65.,
|
237 | 237 | self.n_units = n_units
|
238 | 238 |
|
239 | 239 | ## set up surrogate function for spike emission
|
| 240 | + #self.spike_fx, self.d_spike_fx = arctan_estimator() # |
240 | 241 | self.spike_fx, self.d_spike_fx = triangular_estimator() # straight_through_estimator()
|
241 | 242 |
|
242 | 243 | ## Compartment setup
|
@@ -264,10 +265,11 @@ def _advance_state(t, dt, tau_m, R_m, v_rest, v_reset, v_decay, refract_T,
|
264 | 265 | key, skey = random.split(key, 2)
|
265 | 266 | ## run one integration step for neuronal dynamics
|
266 | 267 | j = j * R_m
|
267 |
| - surrogate = d_spike_fx(j, thr + thr_theta) |
| 268 | + #surrogate = d_spike_fx(v, thr + thr_theta) |
268 | 269 | v, s, raw_spikes, rfr = _run_cell(dt, j, v, thr, thr_theta, rfr, skey,
|
269 | 270 | tau_m, v_rest, v_reset, v_decay, refract_T,
|
270 | 271 | intgFlag)
|
| 272 | + surrogate = d_spike_fx(v, thr + thr_theta) |
271 | 273 | if tau_theta > 0.:
|
272 | 274 | ## run one integration step for threshold dynamics
|
273 | 275 | thr_theta = _update_theta(dt, thr_theta, raw_spikes, tau_theta, theta_plus)
|
|
0 commit comments