1
1
from jax import numpy as jnp , random , jit , nn
2
2
from functools import partial
3
3
from ngclearn .utils import tensorstats
4
+ from ngcsimlib .deprecators import deprecate_args
4
5
from ngclearn import resolver , Component , Compartment
5
6
from ngclearn .components .jaxComponent import JaxComponent
6
7
from ngclearn .utils .diffeq .ode_utils import get_integrator_code , \
7
8
step_euler , step_rk2
8
- from ngclearn .utils .surrogate_fx import secant_lif_estimator , arctan_estimator , triangular_estimator
9
+ from ngclearn .utils .surrogate_fx import (secant_lif_estimator , arctan_estimator ,
10
+ triangular_estimator ,
11
+ straight_through_estimator )
9
12
10
13
@jit
11
14
def _update_times (t , s , tols ):
@@ -25,12 +28,6 @@ def _update_times(t, s, tols):
25
28
_tols = (1. - s ) * tols + (s * t )
26
29
return _tols
27
30
28
- # @jit
29
- # def _modify_current(j, dt, tau_m, R_m):
30
- # ## electrical current re-scaling co-routine
31
- # jScale = tau_m/dt ## <-- this anti-scale counter-balances form of ODE used in this cell
32
- # return (j * R_m) * jScale
33
-
34
31
@jit
35
32
def _dfv_internal (j , v , rfr , tau_m , refract_T , v_rest , v_decay = 1. ): ## raw voltage dynamics
36
33
mask = (rfr >= refract_T ).astype (jnp .float32 ) # get refractory mask
@@ -47,44 +44,7 @@ def _dfv(t, v, params): ## voltage dynamics wrapper
47
44
#@partial(jit, static_argnums=[7, 8, 9, 10, 11, 12])
48
45
def _run_cell (dt , j , v , v_thr , v_theta , rfr , skey , tau_m , v_rest , v_reset ,
49
46
v_decay , refract_T , integType = 0 ):
50
- """
51
- Runs leaky integrator (or leaky integrate-and-fire; LIF) neuronal dynamics.
52
-
53
- Args:
54
- dt: integration time constant (milliseconds, or ms)
55
-
56
- j: electrical current value
57
-
58
- v: membrane potential (voltage, in milliVolts or mV) value (at t)
59
-
60
- v_thr: base voltage threshold value (in mV)
61
-
62
- v_theta: threshold shift (homeostatic) variable (at t)
63
-
64
- rfr: refractory variable vector (one per neuronal cell)
65
-
66
- skey: PRNG key which, if not None, will trigger a single-spike constraint
67
- (i.e., only one spike permitted to emit per single step of time);
68
- specifically used to randomly sample one of the possible action
69
- potentials to be an emitted spike
70
-
71
- tau_m: cell membrane time constant
72
-
73
- v_rest: membrane resting potential (in mV)
74
-
75
- v_reset: membrane reset potential (in mV) -- upon occurrence of a spike,
76
- a neuronal cell's membrane potential will be set to this value
77
-
78
- v_decay: strength of voltage leak (Default: 1.)
79
-
80
- refract_T: (relative) refractory time period (in ms; Default
81
- value is 1 ms)
82
-
83
- integType: integer indicating type of integration to use
84
-
85
- Returns:
86
- voltage(t+dt), spikes, raw spikes, updated refactory variables
87
- """
47
+ ### Runs leaky integrator (or leaky integrate-and-fire; LIF) neuronal dynamics.
88
48
_v_thr = v_theta + v_thr ## calc present voltage threshold
89
49
#mask = (rfr >= refract_T).astype(jnp.float32) # get refractory mask
90
50
## update voltage / membrane potential
@@ -114,24 +74,7 @@ def _run_cell(dt, j, v, v_thr, v_theta, rfr, skey, tau_m, v_rest, v_reset,
114
74
115
75
@partial (jit , static_argnums = [3 , 4 ])
116
76
def _update_theta (dt , v_theta , s , tau_theta , theta_plus = 0.05 ):
117
- """
118
- Runs homeostatic threshold update dynamics one step (via Euler integration).
119
-
120
- Args:
121
- dt: integration time constant (milliseconds, or ms)
122
-
123
- v_theta: current value of homeostatic threshold variable
124
-
125
- s: current spikes (at t)
126
-
127
- tau_theta: homeostatic threshold time constant
128
-
129
- theta_plus: physical increment to be applied to any threshold value if
130
- a spike was emitted
131
-
132
- Returns:
133
- updated homeostatic threshold variable
134
- """
77
+ ### Runs homeostatic threshold update dynamics one step (via Euler integration).
135
78
#theta_decay = 0.9999999 #0.999999762 #jnp.exp(-dt/1e7)
136
79
#theta_plus = 0.05
137
80
#_V_theta = V_theta * theta_decay + S * theta_plus
@@ -205,11 +148,10 @@ class LIFCell(JaxComponent): ## leaky integrate-and-fire cell
205
148
at an increase in computational cost (and simulation time)
206
149
"""
207
150
208
- # Define Functions
151
+ @ deprecate_args ( thr_jitter = None )
209
152
def __init__ (self , name , n_units , tau_m , resist_m = 1. , thr = - 52. , v_rest = - 65. ,
210
153
v_reset = - 60. , v_decay = 1. , tau_theta = 1e7 , theta_plus = 0.05 ,
211
- refract_time = 5. , thr_jitter = 0. , one_spike = False ,
212
- integration_type = "euler" , ** kwargs ):
154
+ refract_time = 5. , one_spike = False , integration_type = "euler" , ** kwargs ):
213
155
super ().__init__ (name , ** kwargs )
214
156
215
157
## Integration properties
@@ -218,15 +160,15 @@ def __init__(self, name, n_units, tau_m, resist_m=1., thr=-52., v_rest=-65.,
218
160
219
161
## membrane parameter setup (affects ODE integration)
220
162
self .tau_m = tau_m ## membrane time constant
221
- self .R_m = resist_m ## resistance value
163
+ self .resist_m = resist_m ## resistance value
222
164
self .one_spike = one_spike ## True => constrains system to simulate 1 spike per time step
223
165
224
166
self .v_rest = v_rest #-65. # mV
225
167
self .v_reset = v_reset # -60. # -65. # mV (milli-volts)
226
168
self .v_decay = v_decay ## controls strength of voltage leak (1 -> LIF, 0 => IF)
227
169
## basic asserts to prevent neuronal dynamics breaking...
228
170
#assert (self.v_decay * self.dt / self.tau_m) <= 1. ## <-- to integrate in verify...
229
- assert self .R_m > 0.
171
+ assert self .resist_m > 0.
230
172
self .tau_theta = tau_theta ## threshold time constant # ms (0 turns off)
231
173
self .theta_plus = theta_plus #0.05 ## threshold increment
232
174
self .refract_T = refract_time #5. # 2. ## refractory period # ms
@@ -237,41 +179,45 @@ def __init__(self, name, n_units, tau_m, resist_m=1., thr=-52., v_rest=-65.,
237
179
self .n_units = n_units
238
180
239
181
## set up surrogate function for spike emission
240
- self .spike_fx , self .d_spike_fx = secant_lif_estimator ()
241
- #self.spike_fx, self.d_spike_fx = arctan_estimator() #
242
- #self.spike_fx, self.d_spike_fx = triangular_estimator() # straight_through_estimator()
182
+ surrgoate_type = "secant_lif"
183
+ if surrgoate_type == "secant_lif" :
184
+ self .spike_fx , self .d_spike_fx = secant_lif_estimator ()
185
+ elif surrgoate_type == "arctan" :
186
+ self .spike_fx , self .d_spike_fx = arctan_estimator ()
187
+ elif surrgoate_type == "triangular" :
188
+ self .spike_fx , self .d_spike_fx = triangular_estimator ()
189
+ else : ## default is the straight-through estimator (STE)
190
+ self .spike_fx , self .d_spike_fx = straight_through_estimator ()
191
+
243
192
244
193
## Compartment setup
245
194
restVals = jnp .zeros ((self .batch_size , self .n_units ))
246
- thr0 = 0.
247
- if thr_jitter > 0. :
248
- key , subkey = random .split (self .key .value )
249
- thr0 = random .uniform (subkey , (1 , n_units ), minval = - thr_jitter ,
250
- maxval = thr_jitter , dtype = jnp .float32 )
251
- self .j = Compartment (restVals )
252
- self .v = Compartment (restVals + self .v_rest )
253
- self .s = Compartment (restVals )
254
- self .s_raw = Compartment (restVals )
255
- self .rfr = Compartment (restVals + self .refract_T )
256
- self .thr_theta = Compartment (restVals + thr0 )
257
- self .tols = Compartment (restVals ) ## time-of-last-spike
258
- self .surrogate = Compartment (restVals + 1. ) ## surrogate signal
195
+ self .j = Compartment (restVals , display_name = "Current" , units = "mA" )
196
+ self .v = Compartment (restVals + self .v_rest ,
197
+ display_name = "Voltage" , units = "mV" )
198
+ self .s = Compartment (restVals , display_name = "Spikes" )
199
+ self .s_raw = Compartment (restVals , display_name = "Raw Spike Pulses" )
200
+ self .rfr = Compartment (restVals + self .refract_T ,
201
+ display_name = "Refractory Time Period" , units = "ms" )
202
+ self .thr_theta = Compartment (restVals , display_name = "Threshold Adaptive Shift" ,
203
+ units = "mV" )
204
+ self .tols = Compartment (restVals , display_name = "Time-of-Last-Spike" ,
205
+ units = "ms" ) ## time-of-last-spike
206
+ self .surrogate = Compartment (restVals + 1. , display_name = "Surrogate State Value" )
259
207
260
208
@staticmethod
261
- def _advance_state (t , dt , tau_m , R_m , v_rest , v_reset , v_decay , refract_T ,
209
+ def _advance_state (t , dt , tau_m , resist_m , v_rest , v_reset , v_decay , refract_T ,
262
210
thr , tau_theta , theta_plus , one_spike , intgFlag , d_spike_fx ,
263
211
key , j , v , s , rfr , thr_theta , tols ):
264
212
skey = None ## this is an empty dkey if single_spike mode turned off
265
213
if one_spike :
266
214
key , skey = random .split (key , 2 )
267
215
## run one integration step for neuronal dynamics
268
- j = j * R_m
269
- #surrogate = d_spike_fx(v, thr + thr_theta)
216
+ j = j * resist_m
270
217
v , s , raw_spikes , rfr = _run_cell (dt , j , v , thr , thr_theta , rfr , skey ,
271
- tau_m , v_rest , v_reset , v_decay , refract_T ,
272
- intgFlag )
218
+ tau_m , v_rest , v_reset , v_decay ,
219
+ refract_T , intgFlag )
273
220
surrogate = d_spike_fx (v , thr + thr_theta )
274
- #surrogate = d_spike_fx(j, thr + thr_theta)
275
221
if tau_theta > 0. :
276
222
## run one integration step for threshold dynamics
277
223
thr_theta = _update_theta (dt , thr_theta , raw_spikes , tau_theta , theta_plus )
@@ -310,22 +256,53 @@ def reset(self, j, v, s, s_raw, rfr, tols, surrogate):
310
256
self .s .set (s )
311
257
self .s_raw .set (s_raw )
312
258
self .rfr .set (rfr )
313
- #self.thr_theta.set(thr_theta)
314
259
self .tols .set (tols )
315
260
self .surrogate .set (surrogate )
316
261
317
262
def save (self , directory , ** kwargs ):
263
+ ## do a protected save of constants, depending on whether they are floats or arrays
264
+ tau_m = (self .tau_m if isinstance (self .tau_m , float )
265
+ else jnp .ones ([[self .tau_m ]]))
266
+ thr = (self .thr if isinstance (self .thr , float )
267
+ else jnp .ones ([[self .thr ]]))
268
+ v_rest = (self .v_rest if isinstance (self .v_rest , float )
269
+ else jnp .ones ([[self .v_rest ]]))
270
+ v_reset = (self .v_reset if isinstance (self .v_reset , float )
271
+ else jnp .ones ([[self .v_reset ]]))
272
+ v_decay = (self .v_decay if isinstance (self .v_decay , float )
273
+ else jnp .ones ([[self .v_decay ]]))
274
+ resist_m = (self .resist_m if isinstance (self .resist_m , float )
275
+ else jnp .ones ([[self .resist_m ]]))
276
+ tau_theta = (self .tau_theta if isinstance (self .tau_theta , float )
277
+ else jnp .ones ([[self .tau_theta ]]))
278
+ theta_plus = (self .theta_plus if isinstance (self .theta_plus , float )
279
+ else jnp .ones ([[self .theta_plus ]]))
280
+
318
281
file_name = directory + "/" + self .name + ".npz"
319
282
jnp .savez (file_name ,
320
283
threshold_theta = self .thr_theta .value ,
284
+ tau_m = tau_m , thr = thr , v_rest = v_rest ,
285
+ v_reset = v_reset , v_decay = v_decay ,
286
+ resist_m = resist_m , tau_theta = tau_theta ,
287
+ theta_plus = theta_plus ,
321
288
key = self .key .value )
322
289
323
290
def load (self , directory , seeded = False , ** kwargs ):
324
291
file_name = directory + "/" + self .name + ".npz"
325
292
data = jnp .load (file_name )
326
- self .thr_theta .set ( data ['threshold_theta' ] )
327
- if seeded == True :
328
- self .key .set ( data ['key' ] )
293
+ self .thr_theta .set (data ['thr_theta' ])
294
+ ## constants loaded in
295
+ self .tau_m = data ['tau_m' ]
296
+ self .thr = data ['thr' ]
297
+ self .v_rest = data ['v_rest' ]
298
+ self .v_reset = data ['v_reset' ]
299
+ self .v_decay = data ['v_decay' ]
300
+ self .resist_m = data ['resist_m' ]
301
+ self .tau_theta = data ['tau_theta' ]
302
+ self .theta_plus = data ['theta_plus' ]
303
+
304
+ if seeded :
305
+ self .key .set (data ['key' ])
329
306
330
307
@classmethod
331
308
def help (cls ): ## component help function
@@ -356,7 +333,6 @@ def help(cls): ## component help function
356
333
"tau_theta" : "Threshold/homoestatic increment time constant" ,
357
334
"theta_plus" : "Amount to increment threshold by upon occurrence of spike" ,
358
335
"refract_time" : "Length of relative refractory period (ms)" ,
359
- "thr_jitter" : "Scale of random uniform noise to apply to initial condition of threshold" ,
360
336
"one_spike" : "Should only one spike be sampled/allowed to emit at any given time step?" ,
361
337
"integration_type" : "Type of numerical integration to use for the cell dynamics"
362
338
}
0 commit comments