Skip to content

Commit cbff94b

Browse files
committed
Merge branch 'main' into monitor_plot
2 parents b4bff5e + 1ddd86d commit cbff94b

File tree

1 file changed

+71
-95
lines changed

1 file changed

+71
-95
lines changed

ngclearn/components/neurons/spiking/LIFCell.py

+71-95
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
from jax import numpy as jnp, random, jit, nn
22
from functools import partial
33
from ngclearn.utils import tensorstats
4+
from ngcsimlib.deprecators import deprecate_args
45
from ngclearn import resolver, Component, Compartment
56
from ngclearn.components.jaxComponent import JaxComponent
67
from ngclearn.utils.diffeq.ode_utils import get_integrator_code, \
78
step_euler, step_rk2
8-
from ngclearn.utils.surrogate_fx import secant_lif_estimator, arctan_estimator, triangular_estimator
9+
from ngclearn.utils.surrogate_fx import (secant_lif_estimator, arctan_estimator,
10+
triangular_estimator,
11+
straight_through_estimator)
912

1013
@jit
1114
def _update_times(t, s, tols):
@@ -25,12 +28,6 @@ def _update_times(t, s, tols):
2528
_tols = (1. - s) * tols + (s * t)
2629
return _tols
2730

28-
# @jit
29-
# def _modify_current(j, dt, tau_m, R_m):
30-
# ## electrical current re-scaling co-routine
31-
# jScale = tau_m/dt ## <-- this anti-scale counter-balances form of ODE used in this cell
32-
# return (j * R_m) * jScale
33-
3431
@jit
3532
def _dfv_internal(j, v, rfr, tau_m, refract_T, v_rest, v_decay=1.): ## raw voltage dynamics
3633
mask = (rfr >= refract_T).astype(jnp.float32) # get refractory mask
@@ -47,44 +44,7 @@ def _dfv(t, v, params): ## voltage dynamics wrapper
4744
#@partial(jit, static_argnums=[7, 8, 9, 10, 11, 12])
4845
def _run_cell(dt, j, v, v_thr, v_theta, rfr, skey, tau_m, v_rest, v_reset,
4946
v_decay, refract_T, integType=0):
50-
"""
51-
Runs leaky integrator (or leaky integrate-and-fire; LIF) neuronal dynamics.
52-
53-
Args:
54-
dt: integration time constant (milliseconds, or ms)
55-
56-
j: electrical current value
57-
58-
v: membrane potential (voltage, in milliVolts or mV) value (at t)
59-
60-
v_thr: base voltage threshold value (in mV)
61-
62-
v_theta: threshold shift (homeostatic) variable (at t)
63-
64-
rfr: refractory variable vector (one per neuronal cell)
65-
66-
skey: PRNG key which, if not None, will trigger a single-spike constraint
67-
(i.e., only one spike permitted to emit per single step of time);
68-
specifically used to randomly sample one of the possible action
69-
potentials to be an emitted spike
70-
71-
tau_m: cell membrane time constant
72-
73-
v_rest: membrane resting potential (in mV)
74-
75-
v_reset: membrane reset potential (in mV) -- upon occurrence of a spike,
76-
a neuronal cell's membrane potential will be set to this value
77-
78-
v_decay: strength of voltage leak (Default: 1.)
79-
80-
refract_T: (relative) refractory time period (in ms; Default
81-
value is 1 ms)
82-
83-
integType: integer indicating type of integration to use
84-
85-
Returns:
86-
voltage(t+dt), spikes, raw spikes, updated refactory variables
87-
"""
47+
### Runs leaky integrator (or leaky integrate-and-fire; LIF) neuronal dynamics.
8848
_v_thr = v_theta + v_thr ## calc present voltage threshold
8949
#mask = (rfr >= refract_T).astype(jnp.float32) # get refractory mask
9050
## update voltage / membrane potential
@@ -114,24 +74,7 @@ def _run_cell(dt, j, v, v_thr, v_theta, rfr, skey, tau_m, v_rest, v_reset,
11474

11575
@partial(jit, static_argnums=[3, 4])
11676
def _update_theta(dt, v_theta, s, tau_theta, theta_plus=0.05):
117-
"""
118-
Runs homeostatic threshold update dynamics one step (via Euler integration).
119-
120-
Args:
121-
dt: integration time constant (milliseconds, or ms)
122-
123-
v_theta: current value of homeostatic threshold variable
124-
125-
s: current spikes (at t)
126-
127-
tau_theta: homeostatic threshold time constant
128-
129-
theta_plus: physical increment to be applied to any threshold value if
130-
a spike was emitted
131-
132-
Returns:
133-
updated homeostatic threshold variable
134-
"""
77+
### Runs homeostatic threshold update dynamics one step (via Euler integration).
13578
#theta_decay = 0.9999999 #0.999999762 #jnp.exp(-dt/1e7)
13679
#theta_plus = 0.05
13780
#_V_theta = V_theta * theta_decay + S * theta_plus
@@ -205,11 +148,10 @@ class LIFCell(JaxComponent): ## leaky integrate-and-fire cell
205148
at an increase in computational cost (and simulation time)
206149
"""
207150

208-
# Define Functions
151+
@deprecate_args(thr_jitter=None)
209152
def __init__(self, name, n_units, tau_m, resist_m=1., thr=-52., v_rest=-65.,
210153
v_reset=-60., v_decay=1., tau_theta=1e7, theta_plus=0.05,
211-
refract_time=5., thr_jitter=0., one_spike=False,
212-
integration_type="euler", **kwargs):
154+
refract_time=5., one_spike=False, integration_type="euler", **kwargs):
213155
super().__init__(name, **kwargs)
214156

215157
## Integration properties
@@ -218,15 +160,15 @@ def __init__(self, name, n_units, tau_m, resist_m=1., thr=-52., v_rest=-65.,
218160

219161
## membrane parameter setup (affects ODE integration)
220162
self.tau_m = tau_m ## membrane time constant
221-
self.R_m = resist_m ## resistance value
163+
self.resist_m = resist_m ## resistance value
222164
self.one_spike = one_spike ## True => constrains system to simulate 1 spike per time step
223165

224166
self.v_rest = v_rest #-65. # mV
225167
self.v_reset = v_reset # -60. # -65. # mV (milli-volts)
226168
self.v_decay = v_decay ## controls strength of voltage leak (1 -> LIF, 0 => IF)
227169
## basic asserts to prevent neuronal dynamics breaking...
228170
#assert (self.v_decay * self.dt / self.tau_m) <= 1. ## <-- to integrate in verify...
229-
assert self.R_m > 0.
171+
assert self.resist_m > 0.
230172
self.tau_theta = tau_theta ## threshold time constant # ms (0 turns off)
231173
self.theta_plus = theta_plus #0.05 ## threshold increment
232174
self.refract_T = refract_time #5. # 2. ## refractory period # ms
@@ -237,41 +179,45 @@ def __init__(self, name, n_units, tau_m, resist_m=1., thr=-52., v_rest=-65.,
237179
self.n_units = n_units
238180

239181
## set up surrogate function for spike emission
240-
self.spike_fx, self.d_spike_fx = secant_lif_estimator()
241-
#self.spike_fx, self.d_spike_fx = arctan_estimator() #
242-
#self.spike_fx, self.d_spike_fx = triangular_estimator() # straight_through_estimator()
182+
surrgoate_type = "secant_lif"
183+
if surrgoate_type == "secant_lif":
184+
self.spike_fx, self.d_spike_fx = secant_lif_estimator()
185+
elif surrgoate_type == "arctan":
186+
self.spike_fx, self.d_spike_fx = arctan_estimator()
187+
elif surrgoate_type == "triangular":
188+
self.spike_fx, self.d_spike_fx = triangular_estimator()
189+
else: ## default is the straight-through estimator (STE)
190+
self.spike_fx, self.d_spike_fx = straight_through_estimator()
191+
243192

244193
## Compartment setup
245194
restVals = jnp.zeros((self.batch_size, self.n_units))
246-
thr0 = 0.
247-
if thr_jitter > 0.:
248-
key, subkey = random.split(self.key.value)
249-
thr0 = random.uniform(subkey, (1, n_units), minval=-thr_jitter,
250-
maxval=thr_jitter, dtype=jnp.float32)
251-
self.j = Compartment(restVals)
252-
self.v = Compartment(restVals + self.v_rest)
253-
self.s = Compartment(restVals)
254-
self.s_raw = Compartment(restVals)
255-
self.rfr = Compartment(restVals + self.refract_T)
256-
self.thr_theta = Compartment(restVals + thr0)
257-
self.tols = Compartment(restVals) ## time-of-last-spike
258-
self.surrogate = Compartment(restVals + 1.) ## surrogate signal
195+
self.j = Compartment(restVals, display_name="Current", units="mA")
196+
self.v = Compartment(restVals + self.v_rest,
197+
display_name="Voltage", units="mV")
198+
self.s = Compartment(restVals, display_name="Spikes")
199+
self.s_raw = Compartment(restVals, display_name="Raw Spike Pulses")
200+
self.rfr = Compartment(restVals + self.refract_T,
201+
display_name="Refractory Time Period", units="ms")
202+
self.thr_theta = Compartment(restVals, display_name="Threshold Adaptive Shift",
203+
units="mV")
204+
self.tols = Compartment(restVals, display_name="Time-of-Last-Spike",
205+
units="ms") ## time-of-last-spike
206+
self.surrogate = Compartment(restVals + 1., display_name="Surrogate State Value")
259207

260208
@staticmethod
261-
def _advance_state(t, dt, tau_m, R_m, v_rest, v_reset, v_decay, refract_T,
209+
def _advance_state(t, dt, tau_m, resist_m, v_rest, v_reset, v_decay, refract_T,
262210
thr, tau_theta, theta_plus, one_spike, intgFlag, d_spike_fx,
263211
key, j, v, s, rfr, thr_theta, tols):
264212
skey = None ## this is an empty dkey if single_spike mode turned off
265213
if one_spike:
266214
key, skey = random.split(key, 2)
267215
## run one integration step for neuronal dynamics
268-
j = j * R_m
269-
#surrogate = d_spike_fx(v, thr + thr_theta)
216+
j = j * resist_m
270217
v, s, raw_spikes, rfr = _run_cell(dt, j, v, thr, thr_theta, rfr, skey,
271-
tau_m, v_rest, v_reset, v_decay, refract_T,
272-
intgFlag)
218+
tau_m, v_rest, v_reset, v_decay,
219+
refract_T, intgFlag)
273220
surrogate = d_spike_fx(v, thr + thr_theta)
274-
#surrogate = d_spike_fx(j, thr + thr_theta)
275221
if tau_theta > 0.:
276222
## run one integration step for threshold dynamics
277223
thr_theta = _update_theta(dt, thr_theta, raw_spikes, tau_theta, theta_plus)
@@ -310,22 +256,53 @@ def reset(self, j, v, s, s_raw, rfr, tols, surrogate):
310256
self.s.set(s)
311257
self.s_raw.set(s_raw)
312258
self.rfr.set(rfr)
313-
#self.thr_theta.set(thr_theta)
314259
self.tols.set(tols)
315260
self.surrogate.set(surrogate)
316261

317262
def save(self, directory, **kwargs):
263+
## do a protected save of constants, depending on whether they are floats or arrays
264+
tau_m = (self.tau_m if isinstance(self.tau_m, float)
265+
else jnp.ones([[self.tau_m]]))
266+
thr = (self.thr if isinstance(self.thr, float)
267+
else jnp.ones([[self.thr]]))
268+
v_rest = (self.v_rest if isinstance(self.v_rest, float)
269+
else jnp.ones([[self.v_rest]]))
270+
v_reset = (self.v_reset if isinstance(self.v_reset, float)
271+
else jnp.ones([[self.v_reset]]))
272+
v_decay = (self.v_decay if isinstance(self.v_decay, float)
273+
else jnp.ones([[self.v_decay]]))
274+
resist_m = (self.resist_m if isinstance(self.resist_m, float)
275+
else jnp.ones([[self.resist_m]]))
276+
tau_theta = (self.tau_theta if isinstance(self.tau_theta, float)
277+
else jnp.ones([[self.tau_theta]]))
278+
theta_plus = (self.theta_plus if isinstance(self.theta_plus, float)
279+
else jnp.ones([[self.theta_plus]]))
280+
318281
file_name = directory + "/" + self.name + ".npz"
319282
jnp.savez(file_name,
320283
threshold_theta=self.thr_theta.value,
284+
tau_m=tau_m, thr=thr, v_rest=v_rest,
285+
v_reset=v_reset, v_decay=v_decay,
286+
resist_m=resist_m, tau_theta=tau_theta,
287+
theta_plus=theta_plus,
321288
key=self.key.value)
322289

323290
def load(self, directory, seeded=False, **kwargs):
324291
file_name = directory + "/" + self.name + ".npz"
325292
data = jnp.load(file_name)
326-
self.thr_theta.set( data['threshold_theta'] )
327-
if seeded == True:
328-
self.key.set( data['key'] )
293+
self.thr_theta.set(data['thr_theta'])
294+
## constants loaded in
295+
self.tau_m = data['tau_m']
296+
self.thr = data['thr']
297+
self.v_rest = data['v_rest']
298+
self.v_reset = data['v_reset']
299+
self.v_decay = data['v_decay']
300+
self.resist_m = data['resist_m']
301+
self.tau_theta = data['tau_theta']
302+
self.theta_plus = data['theta_plus']
303+
304+
if seeded:
305+
self.key.set(data['key'])
329306

330307
@classmethod
331308
def help(cls): ## component help function
@@ -356,7 +333,6 @@ def help(cls): ## component help function
356333
"tau_theta": "Threshold/homoestatic increment time constant",
357334
"theta_plus": "Amount to increment threshold by upon occurrence of spike",
358335
"refract_time": "Length of relative refractory period (ms)",
359-
"thr_jitter": "Scale of random uniform noise to apply to initial condition of threshold",
360336
"one_spike": "Should only one spike be sampled/allowed to emit at any given time step?",
361337
"integration_type": "Type of numerical integration to use for the cell dynamics"
362338
}

0 commit comments

Comments
 (0)