Skip to content

Commit fb8524a

Browse files
committed
integrated resonate-and-fire neuronal cell
1 parent 6ec2e7a commit fb8524a

File tree

4 files changed

+254
-0
lines changed

4 files changed

+254
-0
lines changed

ngclearn/components/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from .neurons.spiking.adExCell import AdExCell
1414
from .neurons.spiking.fitzhughNagumoCell import FitzhughNagumoCell
1515
from .neurons.spiking.izhikevichCell import IzhikevichCell
16+
from .neurons.spiking.RAFCell import RAFCell
1617
## point to transformer/operater component types
1718
from .other.varTrace import VarTrace
1819
from .other.expKernel import ExpKernel

ngclearn/components/neurons/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,4 @@
1212
from .spiking.adExCell import AdExCell
1313
from .spiking.fitzhughNagumoCell import FitzhughNagumoCell
1414
from .spiking.izhikevichCell import IzhikevichCell
15+
from .spiking.RAFCell import RAFCell
Lines changed: 251 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,251 @@
1+
from jax import numpy as jnp, jit
2+
from ngclearn import resolver, Component, Compartment
3+
from ngclearn.components.jaxComponent import JaxComponent
4+
from ngclearn.utils import tensorstats
5+
from ngclearn.utils.diffeq.ode_utils import get_integrator_code, \
6+
step_euler, step_rk2
7+
8+
@jit
9+
def _update_times(t, s, tols):
10+
"""
11+
Updates time-of-last-spike (tols) variable.
12+
13+
Args:
14+
t: current time (a scalar/int value)
15+
16+
s: binary spike vector
17+
18+
tols: current time-of-last-spike variable
19+
20+
Returns:
21+
updated tols variable
22+
"""
23+
_tols = (1. - s) * tols + (s * t)
24+
return _tols
25+
26+
@jit
27+
def _dfv_internal(j, v, w, tau_m, omega, b): ## "voltage" dynamics
28+
# dy/dt = omega x + b y
29+
dv_dt = omega * w + v * b ## dv/dt
30+
dv_dt = dv_dt * (1./tau_m)
31+
return dv_dt
32+
33+
def _dfv(t, v, params): ## voltage dynamics wrapper
34+
j, w, tau_m, omega, b = params
35+
dv_dt = _dfv_internal(j, v, w, tau_m, omega, b)
36+
return dv_dt
37+
38+
@jit
39+
def _dfw_internal(j, v, w, tau_w, omega, b): ## raw angular driver dynamics
40+
# dx/dt = b x − omega y + I
41+
dw_dt = w * b - v * omega + j
42+
dw_dt = dw_dt * (1./tau_w)
43+
return dw_dt
44+
45+
def _dfw(t, w, params): ## angular driver dynamics wrapper
46+
j, v, tau_w, omega, b = params
47+
dv_dt = _dfw_internal(j, v, w, tau_w, omega, b)
48+
return dv_dt
49+
50+
@jit
51+
def _emit_spike(v, v_thr):
52+
s = (v > v_thr).astype(jnp.float32)
53+
return s
54+
55+
class RAFCell(JaxComponent):
56+
"""
57+
The resonate-and-fire (RAF) neuronal cell
58+
model; a two-variable model. This cell model iteratively evolves
59+
voltage "v" and angular driver "w".
60+
61+
The specific pair of differential equations that characterize this cell
62+
are (for adjusting v and w, given current j, over time):
63+
64+
| tau_m * dv/dt = -(v - v_rest) + sharpV * exp((v - vT)/sharpV) - R_m * w + R_m * j
65+
| tau_w * dw/dt = -w + (v - v_rest) * a
66+
| where w = w + s * (w + b) [in the event of a spike]
67+
68+
| --- Cell Input Compartments: ---
69+
| j - electrical current input (takes in external signals)
70+
| --- Cell State Compartments: ---
71+
| v - membrane potential/voltage state
72+
| w - angular driver variable state
73+
| key - JAX PRNG key
74+
| --- Cell Output Compartments: ---
75+
| s - emitted binary spikes/action potentials
76+
| tols - time-of-last-spike
77+
78+
| References:
79+
| Izhikevich, Eugene M. "Resonate-and-fire neurons." Neural networks
80+
| 14.6-7 (2001): 883-894.
81+
82+
Args:
83+
name: the string name of this cell
84+
85+
n_units: number of cellular entities (neural population size)
86+
87+
tau_m: membrane time constant (Default: 15 ms)
88+
89+
resist_m: membrane resistance (Default: 1 mega-Ohm)
90+
91+
tau_w: angular driver variable time constant (Default: 400 ms)
92+
93+
thr: voltage/membrane threshold (to obtain action potentials in terms
94+
of binary spikes) (Default: 5 mV)
95+
96+
v_rest: membrane resting potential (Default: -72 mV)
97+
98+
b: oscillation dampening factor (Default: -1.)
99+
100+
v0: initial condition / reset for voltage (Default: -70 mV)
101+
102+
w0: initial condition / reset for angular driver (Default: 0 mV)
103+
104+
integration_type: type of integration to use for this cell's dynamics;
105+
current supported forms include "euler" (Euler/RK-1 integration)
106+
and "midpoint" or "rk2" (midpoint method/RK-2 integration) (Default: "euler")
107+
108+
:Note: setting the integration type to the midpoint method will
109+
increase the accuray of the estimate of the cell's evolution
110+
at an increase in computational cost (and simulation time)
111+
"""
112+
113+
# Define Functions
114+
def __init__(self, name, n_units, tau_m=15., resist_m=1., tau_w=400.,
115+
omega=10., thr=5., v_rest=-72.,
116+
v_reset=-75., w_reset=0., b=-1., v0=-70., w0=0.,
117+
integration_type="euler", batch_size=1, **kwargs):
118+
super().__init__(name, **kwargs)
119+
120+
## Integration properties
121+
self.integrationType = integration_type
122+
self.intgFlag = get_integrator_code(self.integrationType)
123+
124+
## Cell properties
125+
self.tau_m = tau_m
126+
self.R_m = resist_m
127+
self.tau_w = tau_w
128+
self.omega = omega ## angular frequency
129+
self.b = b ## dampening factor
130+
## note: the smaller b is, the faster the oscillation dampens to resting state values
131+
self.v_rest = v_rest
132+
self.v_reset = v_reset
133+
self.w_reset = w_reset
134+
135+
self.v0 = v0 ## initial membrane potential/voltage condition
136+
self.w0 = w0 ## initial w-parameter condition
137+
self.thr = thr
138+
139+
## Layer Size Setup
140+
self.batch_size = batch_size
141+
self.n_units = n_units
142+
143+
## Compartment setup
144+
restVals = jnp.zeros((self.batch_size, self.n_units))
145+
self.j = Compartment(restVals, display_name="Current", units="mA")
146+
self.v = Compartment(restVals + self.v0, display_name="Voltage", units="mV")
147+
self.w = Compartment(restVals + self.w0, display_name="Angular-Driver")
148+
self.s = Compartment(restVals, display_name="Spikes")
149+
self.tols = Compartment(restVals, display_name="Time-of-Last-Spike",
150+
units="ms") ## time-of-last-spike
151+
152+
@staticmethod
153+
def _advance_state(t, dt, tau_m, R_m, tau_w, thr, omega, b, v_rest,
154+
v_reset, w_reset, intgFlag, j, v, w, tols):
155+
j_ = j * R_m
156+
if intgFlag == 1: ## RK-2/midpoint
157+
w_params = (j_, v, tau_w, omega, b)
158+
_, _w = step_rk2(0., w, _dfw, dt, w_params)
159+
v_params = (j_, w, tau_m, omega, b)
160+
_, _v = step_rk2(0., v, _dfv, dt, v_params)
161+
else: # integType == 0 (default -- Euler)
162+
w_params = (j_, v, tau_w, omega, b)
163+
_, _w = step_euler(0., w, _dfw, dt, w_params)
164+
v_params = (j_, w, tau_m, omega, b)
165+
_, _v = step_euler(0., v, _dfv, dt, v_params)
166+
s = _emit_spike(_v, thr)
167+
## hyperpolarize/reset/snap variables
168+
v = _v * (1. - s) + s * v_reset
169+
w = _w * (1. - s) + s * w_reset
170+
171+
tols = _update_times(t, s, tols)
172+
return j, v, w, s, tols
173+
174+
@resolver(_advance_state)
175+
def advance_state(self, j, v, w, s, tols):
176+
self.j.set(j)
177+
self.w.set(w)
178+
self.v.set(v)
179+
self.s.set(s)
180+
self.tols.set(tols)
181+
182+
@staticmethod
183+
def _reset(batch_size, n_units, v0, w0):
184+
restVals = jnp.zeros((batch_size, n_units))
185+
j = restVals # None
186+
v = restVals + v0
187+
w = restVals + w0
188+
s = restVals #+ 0
189+
tols = restVals #+ 0
190+
return j, v, w, s, tols
191+
192+
@resolver(_reset)
193+
def reset(self, j, v, w, s, tols):
194+
self.j.set(j)
195+
self.v.set(v)
196+
self.w.set(w)
197+
self.s.set(s)
198+
self.tols.set(tols)
199+
200+
@classmethod
201+
def help(cls): ## component help function
202+
properties = {
203+
"cell_type": "RAFCell - evolves neurons according to nonlinear, "
204+
"resonate-and-fire dual-ODE spiking cell dynamics."
205+
}
206+
compartment_props = {
207+
"inputs":
208+
{"j": "External input electrical current",
209+
"key": "JAX PRNG key"},
210+
"states":
211+
{"v": "Membrane potential/voltage at time t",
212+
"w": "Recovery variable at time t"},
213+
"outputs":
214+
{"s": "Emitted spikes/pulses at time t",
215+
"tols": "Time-of-last-spike"},
216+
}
217+
hyperparams = {
218+
"n_units": "Number of neuronal cells to model in this layer",
219+
"batch_size": "Batch size dimension of this component",
220+
"tau_m": "Cell membrane time constant",
221+
"resist_m": "Membrane resistance value",
222+
"tau_w": "Recovery variable time constant",
223+
"v_thr": "Base voltage threshold value",
224+
"v_rest": "Resting membrane potential value",
225+
"v_reset": "Reset membrane potential value",
226+
"b": "Exponential dampening factor applied to oscillations",
227+
"omega": "Angular frequency of neuronal progress per second (radians)",
228+
"v0": "Initial condition for membrane potential/voltage",
229+
"w0": "Initial condition for membrane angular driver variable",
230+
"integration_type": "Type of numerical integration to use for the cell dynamics"
231+
}
232+
info = {cls.__name__: properties,
233+
"compartments": compartment_props,
234+
"dynamics": "tau_m * dv/dt = omega * w + v * b; "
235+
"tau_w * dw/dt = w * b - v * omega + j",
236+
"hyperparameters": hyperparams}
237+
return info
238+
239+
def __repr__(self):
240+
comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))]
241+
maxlen = max(len(c) for c in comps) + 5
242+
lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
243+
for c in comps:
244+
stats = tensorstats(getattr(self, c).value)
245+
if stats is not None:
246+
line = [f"{k}: {v}" for k, v in stats.items()]
247+
line = ", ".join(line)
248+
else:
249+
line = "None"
250+
lines += f" {f'({c})'.ljust(maxlen)}{line}\n"
251+
return lines

ngclearn/components/neurons/spiking/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,4 @@
77
from .adExCell import AdExCell
88
from .fitzhughNagumoCell import FitzhughNagumoCell
99
from .izhikevichCell import IzhikevichCell
10+
from .RAFCell import RAFCell

0 commit comments

Comments
 (0)