|
| 1 | +from jax import numpy as jnp, jit |
| 2 | +from ngclearn import resolver, Component, Compartment |
| 3 | +from ngclearn.components.jaxComponent import JaxComponent |
| 4 | +from ngclearn.utils import tensorstats |
| 5 | +from ngclearn.utils.diffeq.ode_utils import get_integrator_code, \ |
| 6 | + step_euler, step_rk2 |
| 7 | + |
| 8 | +@jit |
| 9 | +def _update_times(t, s, tols): |
| 10 | + """ |
| 11 | + Updates time-of-last-spike (tols) variable. |
| 12 | +
|
| 13 | + Args: |
| 14 | + t: current time (a scalar/int value) |
| 15 | +
|
| 16 | + s: binary spike vector |
| 17 | +
|
| 18 | + tols: current time-of-last-spike variable |
| 19 | +
|
| 20 | + Returns: |
| 21 | + updated tols variable |
| 22 | + """ |
| 23 | + _tols = (1. - s) * tols + (s * t) |
| 24 | + return _tols |
| 25 | + |
| 26 | +@jit |
| 27 | +def _dfv_internal(j, v, w, tau_m, omega, b): ## "voltage" dynamics |
| 28 | + # dy/dt = omega x + b y |
| 29 | + dv_dt = omega * w + v * b ## dv/dt |
| 30 | + dv_dt = dv_dt * (1./tau_m) |
| 31 | + return dv_dt |
| 32 | + |
| 33 | +def _dfv(t, v, params): ## voltage dynamics wrapper |
| 34 | + j, w, tau_m, omega, b = params |
| 35 | + dv_dt = _dfv_internal(j, v, w, tau_m, omega, b) |
| 36 | + return dv_dt |
| 37 | + |
| 38 | +@jit |
| 39 | +def _dfw_internal(j, v, w, tau_w, omega, b): ## raw angular driver dynamics |
| 40 | + # dx/dt = b x − omega y + I |
| 41 | + dw_dt = w * b - v * omega + j |
| 42 | + dw_dt = dw_dt * (1./tau_w) |
| 43 | + return dw_dt |
| 44 | + |
| 45 | +def _dfw(t, w, params): ## angular driver dynamics wrapper |
| 46 | + j, v, tau_w, omega, b = params |
| 47 | + dv_dt = _dfw_internal(j, v, w, tau_w, omega, b) |
| 48 | + return dv_dt |
| 49 | + |
| 50 | +@jit |
| 51 | +def _emit_spike(v, v_thr): |
| 52 | + s = (v > v_thr).astype(jnp.float32) |
| 53 | + return s |
| 54 | + |
| 55 | +class RAFCell(JaxComponent): |
| 56 | + """ |
| 57 | + The resonate-and-fire (RAF) neuronal cell |
| 58 | + model; a two-variable model. This cell model iteratively evolves |
| 59 | + voltage "v" and angular driver "w". |
| 60 | +
|
| 61 | + The specific pair of differential equations that characterize this cell |
| 62 | + are (for adjusting v and w, given current j, over time): |
| 63 | +
|
| 64 | + | tau_m * dv/dt = -(v - v_rest) + sharpV * exp((v - vT)/sharpV) - R_m * w + R_m * j |
| 65 | + | tau_w * dw/dt = -w + (v - v_rest) * a |
| 66 | + | where w = w + s * (w + b) [in the event of a spike] |
| 67 | +
|
| 68 | + | --- Cell Input Compartments: --- |
| 69 | + | j - electrical current input (takes in external signals) |
| 70 | + | --- Cell State Compartments: --- |
| 71 | + | v - membrane potential/voltage state |
| 72 | + | w - angular driver variable state |
| 73 | + | key - JAX PRNG key |
| 74 | + | --- Cell Output Compartments: --- |
| 75 | + | s - emitted binary spikes/action potentials |
| 76 | + | tols - time-of-last-spike |
| 77 | +
|
| 78 | + | References: |
| 79 | + | Izhikevich, Eugene M. "Resonate-and-fire neurons." Neural networks |
| 80 | + | 14.6-7 (2001): 883-894. |
| 81 | +
|
| 82 | + Args: |
| 83 | + name: the string name of this cell |
| 84 | +
|
| 85 | + n_units: number of cellular entities (neural population size) |
| 86 | +
|
| 87 | + tau_m: membrane time constant (Default: 15 ms) |
| 88 | +
|
| 89 | + resist_m: membrane resistance (Default: 1 mega-Ohm) |
| 90 | +
|
| 91 | + tau_w: angular driver variable time constant (Default: 400 ms) |
| 92 | +
|
| 93 | + thr: voltage/membrane threshold (to obtain action potentials in terms |
| 94 | + of binary spikes) (Default: 5 mV) |
| 95 | +
|
| 96 | + v_rest: membrane resting potential (Default: -72 mV) |
| 97 | +
|
| 98 | + b: oscillation dampening factor (Default: -1.) |
| 99 | +
|
| 100 | + v0: initial condition / reset for voltage (Default: -70 mV) |
| 101 | +
|
| 102 | + w0: initial condition / reset for angular driver (Default: 0 mV) |
| 103 | +
|
| 104 | + integration_type: type of integration to use for this cell's dynamics; |
| 105 | + current supported forms include "euler" (Euler/RK-1 integration) |
| 106 | + and "midpoint" or "rk2" (midpoint method/RK-2 integration) (Default: "euler") |
| 107 | +
|
| 108 | + :Note: setting the integration type to the midpoint method will |
| 109 | + increase the accuray of the estimate of the cell's evolution |
| 110 | + at an increase in computational cost (and simulation time) |
| 111 | + """ |
| 112 | + |
| 113 | + # Define Functions |
| 114 | + def __init__(self, name, n_units, tau_m=15., resist_m=1., tau_w=400., |
| 115 | + omega=10., thr=5., v_rest=-72., |
| 116 | + v_reset=-75., w_reset=0., b=-1., v0=-70., w0=0., |
| 117 | + integration_type="euler", batch_size=1, **kwargs): |
| 118 | + super().__init__(name, **kwargs) |
| 119 | + |
| 120 | + ## Integration properties |
| 121 | + self.integrationType = integration_type |
| 122 | + self.intgFlag = get_integrator_code(self.integrationType) |
| 123 | + |
| 124 | + ## Cell properties |
| 125 | + self.tau_m = tau_m |
| 126 | + self.R_m = resist_m |
| 127 | + self.tau_w = tau_w |
| 128 | + self.omega = omega ## angular frequency |
| 129 | + self.b = b ## dampening factor |
| 130 | + ## note: the smaller b is, the faster the oscillation dampens to resting state values |
| 131 | + self.v_rest = v_rest |
| 132 | + self.v_reset = v_reset |
| 133 | + self.w_reset = w_reset |
| 134 | + |
| 135 | + self.v0 = v0 ## initial membrane potential/voltage condition |
| 136 | + self.w0 = w0 ## initial w-parameter condition |
| 137 | + self.thr = thr |
| 138 | + |
| 139 | + ## Layer Size Setup |
| 140 | + self.batch_size = batch_size |
| 141 | + self.n_units = n_units |
| 142 | + |
| 143 | + ## Compartment setup |
| 144 | + restVals = jnp.zeros((self.batch_size, self.n_units)) |
| 145 | + self.j = Compartment(restVals, display_name="Current", units="mA") |
| 146 | + self.v = Compartment(restVals + self.v0, display_name="Voltage", units="mV") |
| 147 | + self.w = Compartment(restVals + self.w0, display_name="Angular-Driver") |
| 148 | + self.s = Compartment(restVals, display_name="Spikes") |
| 149 | + self.tols = Compartment(restVals, display_name="Time-of-Last-Spike", |
| 150 | + units="ms") ## time-of-last-spike |
| 151 | + |
| 152 | + @staticmethod |
| 153 | + def _advance_state(t, dt, tau_m, R_m, tau_w, thr, omega, b, v_rest, |
| 154 | + v_reset, w_reset, intgFlag, j, v, w, tols): |
| 155 | + j_ = j * R_m |
| 156 | + if intgFlag == 1: ## RK-2/midpoint |
| 157 | + w_params = (j_, v, tau_w, omega, b) |
| 158 | + _, _w = step_rk2(0., w, _dfw, dt, w_params) |
| 159 | + v_params = (j_, w, tau_m, omega, b) |
| 160 | + _, _v = step_rk2(0., v, _dfv, dt, v_params) |
| 161 | + else: # integType == 0 (default -- Euler) |
| 162 | + w_params = (j_, v, tau_w, omega, b) |
| 163 | + _, _w = step_euler(0., w, _dfw, dt, w_params) |
| 164 | + v_params = (j_, w, tau_m, omega, b) |
| 165 | + _, _v = step_euler(0., v, _dfv, dt, v_params) |
| 166 | + s = _emit_spike(_v, thr) |
| 167 | + ## hyperpolarize/reset/snap variables |
| 168 | + v = _v * (1. - s) + s * v_reset |
| 169 | + w = _w * (1. - s) + s * w_reset |
| 170 | + |
| 171 | + tols = _update_times(t, s, tols) |
| 172 | + return j, v, w, s, tols |
| 173 | + |
| 174 | + @resolver(_advance_state) |
| 175 | + def advance_state(self, j, v, w, s, tols): |
| 176 | + self.j.set(j) |
| 177 | + self.w.set(w) |
| 178 | + self.v.set(v) |
| 179 | + self.s.set(s) |
| 180 | + self.tols.set(tols) |
| 181 | + |
| 182 | + @staticmethod |
| 183 | + def _reset(batch_size, n_units, v0, w0): |
| 184 | + restVals = jnp.zeros((batch_size, n_units)) |
| 185 | + j = restVals # None |
| 186 | + v = restVals + v0 |
| 187 | + w = restVals + w0 |
| 188 | + s = restVals #+ 0 |
| 189 | + tols = restVals #+ 0 |
| 190 | + return j, v, w, s, tols |
| 191 | + |
| 192 | + @resolver(_reset) |
| 193 | + def reset(self, j, v, w, s, tols): |
| 194 | + self.j.set(j) |
| 195 | + self.v.set(v) |
| 196 | + self.w.set(w) |
| 197 | + self.s.set(s) |
| 198 | + self.tols.set(tols) |
| 199 | + |
| 200 | + @classmethod |
| 201 | + def help(cls): ## component help function |
| 202 | + properties = { |
| 203 | + "cell_type": "RAFCell - evolves neurons according to nonlinear, " |
| 204 | + "resonate-and-fire dual-ODE spiking cell dynamics." |
| 205 | + } |
| 206 | + compartment_props = { |
| 207 | + "inputs": |
| 208 | + {"j": "External input electrical current", |
| 209 | + "key": "JAX PRNG key"}, |
| 210 | + "states": |
| 211 | + {"v": "Membrane potential/voltage at time t", |
| 212 | + "w": "Recovery variable at time t"}, |
| 213 | + "outputs": |
| 214 | + {"s": "Emitted spikes/pulses at time t", |
| 215 | + "tols": "Time-of-last-spike"}, |
| 216 | + } |
| 217 | + hyperparams = { |
| 218 | + "n_units": "Number of neuronal cells to model in this layer", |
| 219 | + "batch_size": "Batch size dimension of this component", |
| 220 | + "tau_m": "Cell membrane time constant", |
| 221 | + "resist_m": "Membrane resistance value", |
| 222 | + "tau_w": "Recovery variable time constant", |
| 223 | + "v_thr": "Base voltage threshold value", |
| 224 | + "v_rest": "Resting membrane potential value", |
| 225 | + "v_reset": "Reset membrane potential value", |
| 226 | + "b": "Exponential dampening factor applied to oscillations", |
| 227 | + "omega": "Angular frequency of neuronal progress per second (radians)", |
| 228 | + "v0": "Initial condition for membrane potential/voltage", |
| 229 | + "w0": "Initial condition for membrane angular driver variable", |
| 230 | + "integration_type": "Type of numerical integration to use for the cell dynamics" |
| 231 | + } |
| 232 | + info = {cls.__name__: properties, |
| 233 | + "compartments": compartment_props, |
| 234 | + "dynamics": "tau_m * dv/dt = omega * w + v * b; " |
| 235 | + "tau_w * dw/dt = w * b - v * omega + j", |
| 236 | + "hyperparameters": hyperparams} |
| 237 | + return info |
| 238 | + |
| 239 | + def __repr__(self): |
| 240 | + comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))] |
| 241 | + maxlen = max(len(c) for c in comps) + 5 |
| 242 | + lines = f"[{self.__class__.__name__}] PATH: {self.name}\n" |
| 243 | + for c in comps: |
| 244 | + stats = tensorstats(getattr(self, c).value) |
| 245 | + if stats is not None: |
| 246 | + line = [f"{k}: {v}" for k, v in stats.items()] |
| 247 | + line = ", ".join(line) |
| 248 | + else: |
| 249 | + line = "None" |
| 250 | + lines += f" {f'({c})'.ljust(maxlen)}{line}\n" |
| 251 | + return lines |
0 commit comments