5
5
from ngclearn .components .jaxComponent import JaxComponent
6
6
from ngclearn .utils .diffeq .ode_utils import get_integrator_code , \
7
7
step_euler , step_rk2
8
+ from ngclearn .utils .surrogate_fx import straight_through_estimator
8
9
9
10
@jit
10
11
def _update_times (t , s , tols ):
@@ -235,6 +236,9 @@ def __init__(self, name, n_units, tau_m, resist_m=1., thr=-52., v_rest=-65.,
235
236
self .batch_size = 1
236
237
self .n_units = n_units
237
238
239
+ ## set up surrogate function for spike emission
240
+ self .spike_fx , self .d_spike_fx = straight_through_estimator ()
241
+
238
242
## Compartment setup
239
243
restVals = jnp .zeros ((self .batch_size , self .n_units ))
240
244
thr0 = 0.
@@ -249,17 +253,19 @@ def __init__(self, name, n_units, tau_m, resist_m=1., thr=-52., v_rest=-65.,
249
253
self .rfr = Compartment (restVals + self .refract_T )
250
254
self .thr_theta = Compartment (restVals + thr0 )
251
255
self .tols = Compartment (restVals ) ## time-of-last-spike
256
+ self .surrogate = Compartment (restVals + 1. ) ## surrogate signal
252
257
253
258
@staticmethod
254
259
def _advance_state (t , dt , tau_m , R_m , v_rest , v_reset , v_decay , refract_T ,
255
- thr , tau_theta , theta_plus , one_spike , intgFlag ,
260
+ thr , tau_theta , theta_plus , one_spike , intgFlag , d_spike_fx ,
256
261
key , j , v , s , rfr , thr_theta , tols ):
257
262
skey = None ## this is an empty dkey if single_spike mode turned off
258
263
if one_spike : ## old code ~> if self.one_spike is False:
259
264
key , skey = random .split (key , 2 )
260
265
## run one integration step for neuronal dynamics
261
266
#j = _modify_current(j, dt, tau_m, R_m) ## re-scale current in prep for volt ODE
262
267
j = j * R_m
268
+ surrogate = d_spike_fx (j )
263
269
v , s , raw_spikes , rfr = _run_cell (dt , j , v , thr , thr_theta , rfr , skey ,
264
270
tau_m , v_rest , v_reset , v_decay , refract_T ,
265
271
intgFlag )
@@ -268,17 +274,18 @@ def _advance_state(t, dt, tau_m, R_m, v_rest, v_reset, v_decay, refract_T,
268
274
thr_theta = _update_theta (dt , thr_theta , raw_spikes , tau_theta , theta_plus )
269
275
## update tols
270
276
tols = _update_times (t , s , tols )
271
- return v , s , raw_spikes , rfr , thr_theta , tols , key
277
+ return v , s , raw_spikes , rfr , thr_theta , tols , key , surrogate
272
278
273
279
@resolver (_advance_state )
274
- def advance_state (self , v , s , s_raw , rfr , thr_theta , tols , key ):
280
+ def advance_state (self , v , s , s_raw , rfr , thr_theta , tols , key , surrogate ):
275
281
self .v .set (v )
276
282
self .s .set (s )
277
283
self .s_raw .set (s_raw )
278
284
self .rfr .set (rfr )
279
285
self .thr_theta .set (thr_theta )
280
286
self .tols .set (tols )
281
287
self .key .set (key )
288
+ self .surrogate .set (surrogate )
282
289
283
290
@staticmethod
284
291
def _reset (batch_size , n_units , v_rest , refract_T ):
@@ -290,17 +297,19 @@ def _reset(batch_size, n_units, v_rest, refract_T):
290
297
rfr = restVals + refract_T
291
298
#thr_theta = restVals ## do not reset thr_theta
292
299
tols = restVals #+ 0
293
- return j , v , s , s_raw , rfr , tols
300
+ surrogate = restVals + 1.
301
+ return j , v , s , s_raw , rfr , tols , surrogate
294
302
295
303
@resolver (_reset )
296
- def reset (self , j , v , s , s_raw , rfr , tols ):
304
+ def reset (self , j , v , s , s_raw , rfr , tols , surrogate ):
297
305
self .j .set (j )
298
306
self .v .set (v )
299
307
self .s .set (s )
300
308
self .s_raw .set (s_raw )
301
309
self .rfr .set (rfr )
302
310
#self.thr_theta.set(thr_theta)
303
311
self .tols .set (tols )
312
+ self .surrogate .set (surrogate )
304
313
305
314
def save (self , directory , ** kwargs ):
306
315
file_name = directory + "/" + self .name + ".npz"
0 commit comments