Skip to content

Commit cdea291

Browse files
committed
Merge branch 'dynamics' of github.com:NACLab/ngc-learn into dynamics
2 parents 05a97f0 + 27a61ef commit cdea291

File tree

8 files changed

+369
-19
lines changed

8 files changed

+369
-19
lines changed

Diff for: ngclearn/components/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from .neurons.graded.rewardErrorCell import RewardErrorCell
77
## point to standard spiking cell component types
88
from .neurons.spiking.sLIFCell import SLIFCell
9+
from .neurons.spiking.IFCell import IFCell
910
from .neurons.spiking.LIFCell import LIFCell
1011
from .neurons.spiking.WTASCell import WTASCell
1112
from .neurons.spiking.quadLIFCell import QuadLIFCell

Diff for: ngclearn/components/input_encoders/bernoulliCell.py

+23-7
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
from jax import numpy as jnp, random, jit
44
from ngclearn.utils import tensorstats
55
from functools import partial
6+
from ngcsimlib.deprecators import deprecate_args
7+
from ngcsimlib.logger import info, warn
68

79
@jit
810
def _update_times(t, s, tols):
@@ -79,15 +81,15 @@ class BernoulliCell(JaxComponent):
7981
8082
n_units: number of cellular entities (neural population size)
8183
82-
max_freq: maximum frequency (in Hertz) of this Bernoulli spike train (must be > 0.)
84+
target_freq: maximum frequency (in Hertz) of this Bernoulli spike train (must be > 0.)
8385
"""
8486

85-
# Define Functions
86-
def __init__(self, name, n_units, max_freq=63.75, batch_size=1, **kwargs):
87+
@deprecate_args(target_freq="max_freq")
88+
def __init__(self, name, n_units, target_freq=63.75, batch_size=1, **kwargs):
8789
super().__init__(name, **kwargs)
8890

8991
## Constrained Bernoulli meta-parameters
90-
self.max_freq = max_freq ## maximum frequency (in Hertz/Hz)
92+
self.target_freq = target_freq ## maximum frequency (in Hertz/Hz)
9193

9294
## Layer Size Setup
9395
self.batch_size = batch_size
@@ -99,12 +101,26 @@ def __init__(self, name, n_units, max_freq=63.75, batch_size=1, **kwargs):
99101
self.outputs = Compartment(restVals, display_name="Spikes") # output compartment
100102
self.tols = Compartment(restVals, display_name="Time-of-Last-Spike", units="ms") # time of last spike
101103

104+
def validate(self, dt, **validation_kwargs):
105+
## check for unstable combinations of dt and target-frequency meta-params
106+
valid = super().validate(**validation_kwargs)
107+
events_per_timestep = (dt/1000.) * self.target_freq ## compute scaled probability
108+
if events_per_timestep > 1.:
109+
valid = False
110+
warn(
111+
f"{self.name} will be unable to make as many temporal events as "
112+
f"requested! ({events_per_timestep} events/timestep) Unstable "
113+
f"combination of dt = {dt} and target_freq = {self.target_freq} "
114+
f"being used!"
115+
)
116+
return valid
117+
102118
@staticmethod
103-
def _advance_state(t, dt, max_freq, key, inputs, tols):
119+
def _advance_state(t, dt, target_freq, key, inputs, tols):
104120
key, *subkeys = random.split(key, 2)
105-
if max_freq > 0.:
121+
if target_freq > 0.:
106122
outputs = _sample_constrained_bernoulli( ## sample Bernoulli w/ target rate
107-
subkeys[0], data=inputs, dt=dt, fmax=max_freq
123+
subkeys[0], data=inputs, dt=dt, fmax=target_freq
108124
)
109125
else:
110126
outputs = _sample_bernoulli(subkeys[0], data=inputs)

Diff for: ngclearn/components/input_encoders/latencyCell.py

+1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from ngclearn.utils.model_utils import clamp_min, clamp_max
55
from jax import numpy as jnp, random, jit
66
from functools import partial
7+
from ngcsimlib.logger import info
78

89
@jit
910
def _update_times(t, s, tols):

Diff for: ngclearn/components/jaxComponent.py

-1
Original file line numberDiff line numberDiff line change
@@ -21,4 +21,3 @@ def __init__(self, name, key=None, directory=None, **kwargs):
2121
self.directory = directory
2222
self.key = Compartment(
2323
random.PRNGKey(time.time_ns()) if key is None else key)
24-

Diff for: ngclearn/components/neurons/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from .graded.rewardErrorCell import RewardErrorCell
66
## point to standard spiking cell component types
77
from .spiking.sLIFCell import SLIFCell
8+
from .spiking.IFCell import IFCell
89
from .spiking.LIFCell import LIFCell
910
from .spiking.WTASCell import WTASCell
1011
from .spiking.quadLIFCell import QuadLIFCell

Diff for: ngclearn/components/neurons/spiking/IFCell.py

+306
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,306 @@
1+
from jax import numpy as jnp, random, jit, nn
2+
from ngclearn.utils import tensorstats
3+
from ngcsimlib.deprecators import deprecate_args
4+
from ngclearn import resolver, Component, Compartment
5+
from ngclearn.components.jaxComponent import JaxComponent
6+
from ngclearn.utils.diffeq.ode_utils import get_integrator_code, \
7+
step_euler, step_rk2
8+
from ngclearn.utils.surrogate_fx import (arctan_estimator,
9+
triangular_estimator,
10+
straight_through_estimator)
11+
12+
@jit
13+
def _update_times(t, s, tols):
14+
"""
15+
Updates time-of-last-spike (tols) variable.
16+
17+
Args:
18+
t: current time (a scalar/int value)
19+
20+
s: binary spike vector
21+
22+
tols: current time-of-last-spike variable
23+
24+
Returns:
25+
updated tols variable
26+
"""
27+
_tols = (1. - s) * tols + (s * t)
28+
return _tols
29+
30+
@jit
31+
def _dfv_internal(j, v, rfr, tau_m, refract_T): ## raw voltage dynamics
32+
mask = (rfr >= refract_T).astype(jnp.float32) # get refractory mask
33+
## update voltage / membrane potential
34+
dv_dt = (j * mask) ## integration only involves electrical current
35+
dv_dt = dv_dt * (1./tau_m)
36+
return dv_dt
37+
38+
def _dfv(t, v, params): ## voltage dynamics wrapper
39+
j, rfr, tau_m, refract_T = params
40+
dv_dt = _dfv_internal(j, v, rfr, tau_m, refract_T)
41+
return dv_dt
42+
43+
def _run_cell(dt, j, v, v_thr, rfr, tau_m, v_rest, v_reset, refract_T, integType=0):
44+
### Runs integrator (or integrate-and-fire; IF) neuronal dynamics
45+
## update voltage / membrane potential
46+
v_params = (j, rfr, tau_m, refract_T)
47+
if integType == 1:
48+
_, _v = step_rk2(0., v, _dfv, dt, v_params)
49+
else:
50+
_, _v = step_euler(0., v, _dfv, dt, v_params)
51+
## obtain action potentials/spikes
52+
s = (_v > v_thr).astype(jnp.float32)
53+
## update refractory variables
54+
_rfr = (rfr + dt) * (1. - s)
55+
## perform hyper-polarization of neuronal cells
56+
_v = _v * (1. - s) + s * v_reset
57+
return _v, s, _rfr
58+
59+
class IFCell(JaxComponent): ## integrate-and-fire cell
60+
"""
61+
A spiking cell based on integrate-and-fire (IF) neuronal dynamics.
62+
63+
The specific differential equation that characterizes this cell
64+
is (for adjusting v, given current j, over time) is:
65+
66+
| tau_m * dv/dt = (v_rest - v) + j * R
67+
| where R is the membrane resistance and v_rest is the resting potential
68+
| also, if a spike occurs, v is set to v_reset
69+
70+
| --- Cell Input Compartments: ---
71+
| j - electrical current input (takes in external signals)
72+
| --- Cell State Compartments: ---
73+
| v - membrane potential/voltage state
74+
| rfr - (relative) refractory variable state
75+
| key - JAX PRNG key
76+
| --- Cell Output Compartments: ---
77+
| s - emitted binary spikes/action potentials
78+
| s_raw - raw spike signals before post-processing (only if one_spike = True, else s_raw = s)
79+
| tols - time-of-last-spike
80+
81+
Args:
82+
name: the string name of this cell
83+
84+
n_units: number of cellular entities (neural population size)
85+
86+
tau_m: membrane time constant
87+
88+
resist_m: membrane resistance value (default: 1)
89+
90+
thr: base value for adaptive thresholds that govern short-term
91+
plasticity (in milliVolts, or mV; default: -52. mV)
92+
93+
v_rest: membrane resting potential (in mV; default: -65 mV)
94+
95+
v_reset: membrane reset potential (in mV) -- upon occurrence of a spike,
96+
a neuronal cell's membrane potential will be set to this value;
97+
(default: -60 mV)
98+
99+
refract_time: relative refractory period time (ms; default: 0 ms)
100+
101+
integration_type: type of integration to use for this cell's dynamics;
102+
current supported forms include "euler" (Euler/RK-1 integration)
103+
and "midpoint" or "rk2" (midpoint method/RK-2 integration) (Default: "euler")
104+
105+
:Note: setting the integration type to the midpoint method will
106+
increase the accuray of the estimate of the cell's evolution
107+
at an increase in computational cost (and simulation time)
108+
109+
surrgoate_type: type of surrogate function to use for approximating a
110+
partial derivative of this cell's spikes w.r.t. its voltage/current
111+
(default: "straight_through")
112+
113+
:Note: surrogate options available include: "straight_through"
114+
(straight-through estimator), "triangular" (triangular estimator),
115+
and "arctan" (arc-tangent estimator)
116+
117+
lower_clamp_voltage: if True, this will ensure voltage never is below
118+
the value of `v_rest` (default: True)
119+
"""
120+
121+
@deprecate_args(thr_jitter=None)
122+
def __init__(self, name, n_units, tau_m, resist_m=1., thr=-52., v_rest=-65.,
123+
v_reset=-60., refract_time=0., integration_type="euler",
124+
surrgoate_type="straight_through", lower_clamp_voltage=True,
125+
**kwargs):
126+
super().__init__(name, **kwargs)
127+
128+
## Integration properties
129+
self.integrationType = integration_type
130+
self.intgFlag = get_integrator_code(self.integrationType)
131+
132+
## membrane parameter setup (affects ODE integration)
133+
self.tau_m = tau_m ## membrane time constant
134+
self.resist_m = resist_m ## resistance value
135+
136+
self.v_rest = v_rest #-65. # mV
137+
self.v_reset = v_reset # -60. # -65. # mV (milli-volts)
138+
## basic asserts to prevent neuronal dynamics breaking...
139+
assert self.resist_m > 0.
140+
self.refract_T = refract_time #5. # 2. ## refractory period # ms
141+
self.thr = thr ## (fixed) base value for threshold #-52 # -72. # mV
142+
self.lower_clamp_voltage = lower_clamp_voltage
143+
144+
## Layer Size Setup
145+
self.batch_size = 1
146+
self.n_units = n_units
147+
148+
## set up surrogate function for spike emission
149+
if surrgoate_type == "arctan":
150+
self.spike_fx, self.d_spike_fx = arctan_estimator()
151+
elif surrgoate_type == "triangular":
152+
self.spike_fx, self.d_spike_fx = triangular_estimator()
153+
else: ## default: straight_through
154+
self.spike_fx, self.d_spike_fx = straight_through_estimator()
155+
156+
157+
## Compartment setup
158+
restVals = jnp.zeros((self.batch_size, self.n_units))
159+
self.j = Compartment(restVals, display_name="Current", units="mA")
160+
self.v = Compartment(restVals + self.v_rest,
161+
display_name="Voltage", units="mV")
162+
self.s = Compartment(restVals, display_name="Spikes")
163+
self.rfr = Compartment(restVals + self.refract_T,
164+
display_name="Refractory Time Period", units="ms")
165+
self.tols = Compartment(restVals, display_name="Time-of-Last-Spike",
166+
units="ms") ## time-of-last-spike
167+
self.surrogate = Compartment(restVals + 1., display_name="Surrogate State Value")
168+
169+
@staticmethod
170+
def _advance_state(t, dt, tau_m, resist_m, v_rest, v_reset, refract_T,
171+
thr, lower_clamp_voltage, intgFlag, d_spike_fx, key,
172+
j, v, rfr, tols):
173+
## run one integration step for neuronal dynamics
174+
j = j * resist_m
175+
v, s, rfr = _run_cell(dt, j, v, thr, rfr, tau_m, v_rest, v_reset,
176+
refract_T, intgFlag)
177+
surrogate = d_spike_fx(v, thr)
178+
## update tols
179+
tols = _update_times(t, s, tols)
180+
if lower_clamp_voltage: ## ensure voltage never < v_rest
181+
v = jnp.maximum(v, v_rest)
182+
return v, s, rfr, tols, key, surrogate
183+
184+
@resolver(_advance_state)
185+
def advance_state(self, v, s, rfr, tols, key, surrogate):
186+
self.v.set(v)
187+
self.s.set(s)
188+
self.rfr.set(rfr)
189+
self.tols.set(tols)
190+
self.key.set(key)
191+
self.surrogate.set(surrogate)
192+
193+
@staticmethod
194+
def _reset(batch_size, n_units, v_rest, refract_T):
195+
restVals = jnp.zeros((batch_size, n_units))
196+
j = restVals #+ 0
197+
v = restVals + v_rest
198+
s = restVals #+ 0
199+
rfr = restVals + refract_T
200+
tols = restVals #+ 0
201+
surrogate = restVals + 1.
202+
return j, v, s, rfr, tols, surrogate
203+
204+
@resolver(_reset)
205+
def reset(self, j, v, s, rfr, tols, surrogate):
206+
self.j.set(j)
207+
self.v.set(v)
208+
self.s.set(s)
209+
self.rfr.set(rfr)
210+
self.tols.set(tols)
211+
self.surrogate.set(surrogate)
212+
213+
def save(self, directory, **kwargs):
214+
## do a protected save of constants, depending on whether they are floats or arrays
215+
tau_m = (self.tau_m if isinstance(self.tau_m, float)
216+
else jnp.ones([[self.tau_m]]))
217+
thr = (self.thr if isinstance(self.thr, float)
218+
else jnp.ones([[self.thr]]))
219+
v_rest = (self.v_rest if isinstance(self.v_rest, float)
220+
else jnp.ones([[self.v_rest]]))
221+
v_reset = (self.v_reset if isinstance(self.v_reset, float)
222+
else jnp.ones([[self.v_reset]]))
223+
v_decay = (self.v_decay if isinstance(self.v_decay, float)
224+
else jnp.ones([[self.v_decay]]))
225+
resist_m = (self.resist_m if isinstance(self.resist_m, float)
226+
else jnp.ones([[self.resist_m]]))
227+
tau_theta = (self.tau_theta if isinstance(self.tau_theta, float)
228+
else jnp.ones([[self.tau_theta]]))
229+
theta_plus = (self.theta_plus if isinstance(self.theta_plus, float)
230+
else jnp.ones([[self.theta_plus]]))
231+
232+
file_name = directory + "/" + self.name + ".npz"
233+
jnp.savez(file_name,
234+
tau_m=tau_m, thr=thr, v_rest=v_rest,
235+
v_reset=v_reset, v_decay=v_decay,
236+
resist_m=resist_m, tau_theta=tau_theta,
237+
theta_plus=theta_plus,
238+
key=self.key.value)
239+
240+
def load(self, directory, seeded=False, **kwargs):
241+
file_name = directory + "/" + self.name + ".npz"
242+
data = jnp.load(file_name)
243+
## constants loaded in
244+
self.tau_m = data['tau_m']
245+
self.thr = data['thr']
246+
self.v_rest = data['v_rest']
247+
self.v_reset = data['v_reset']
248+
self.v_decay = data['v_decay']
249+
self.resist_m = data['resist_m']
250+
self.tau_theta = data['tau_theta']
251+
self.theta_plus = data['theta_plus']
252+
253+
if seeded:
254+
self.key.set(data['key'])
255+
256+
@classmethod
257+
def help(cls): ## component help function
258+
properties = {
259+
"cell_type": "IFCell - evolves neurons according to integrate-"
260+
"and-fire spiking dynamics."
261+
}
262+
compartment_props = {
263+
"inputs":
264+
{"j": "External input electrical current"},
265+
"states":
266+
{"v": "Membrane potential/voltage at time t",
267+
"rfr": "Current state of (relative) refractory variable",
268+
"thr": "Current state of voltage threshold at time t",
269+
"key": "JAX PRNG key"},
270+
"outputs":
271+
{"s": "Emitted spikes/pulses at time t",
272+
"tols": "Time-of-last-spike"},
273+
}
274+
hyperparams = {
275+
"n_units": "Number of neuronal cells to model in this layer",
276+
"tau_m": "Cell membrane time constant",
277+
"resist_m": "Membrane resistance value",
278+
"thr": "Base voltage threshold value",
279+
"v_rest": "Resting membrane potential value",
280+
"v_reset": "Reset membrane potential value",
281+
"refract_time": "Length of relative refractory period (ms)",
282+
"integration_type": "Type of numerical integration to use for the cell dynamics",
283+
"surrgoate_type": "Type of surrogate function to use approximate "
284+
"derivative of spike w.r.t. voltage/current",
285+
"lower_bound_clamp": "Should voltage be lower bounded to be never be below `v_rest`"
286+
}
287+
info = {cls.__name__: properties,
288+
"compartments": compartment_props,
289+
"dynamics": "tau_m * dv/dt = (v_rest - v) + j * resist_m",
290+
"hyperparameters": hyperparams}
291+
return info
292+
293+
def __repr__(self):
294+
comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))]
295+
maxlen = max(len(c) for c in comps) + 5
296+
lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
297+
for c in comps:
298+
stats = tensorstats(getattr(self, c).value)
299+
if stats is not None:
300+
line = [f"{k}: {v}" for k, v in stats.items()]
301+
line = ", ".join(line)
302+
else:
303+
line = "None"
304+
lines += f" {f'({c})'.ljust(maxlen)}{line}\n"
305+
return lines
306+

0 commit comments

Comments
 (0)