|
| 1 | +from jax import numpy as jnp, random, jit, nn |
| 2 | +from ngclearn.utils import tensorstats |
| 3 | +from ngcsimlib.deprecators import deprecate_args |
| 4 | +from ngclearn import resolver, Component, Compartment |
| 5 | +from ngclearn.components.jaxComponent import JaxComponent |
| 6 | +from ngclearn.utils.diffeq.ode_utils import get_integrator_code, \ |
| 7 | + step_euler, step_rk2 |
| 8 | +from ngclearn.utils.surrogate_fx import (arctan_estimator, |
| 9 | + triangular_estimator, |
| 10 | + straight_through_estimator) |
| 11 | + |
| 12 | +@jit |
| 13 | +def _update_times(t, s, tols): |
| 14 | + """ |
| 15 | + Updates time-of-last-spike (tols) variable. |
| 16 | +
|
| 17 | + Args: |
| 18 | + t: current time (a scalar/int value) |
| 19 | +
|
| 20 | + s: binary spike vector |
| 21 | +
|
| 22 | + tols: current time-of-last-spike variable |
| 23 | +
|
| 24 | + Returns: |
| 25 | + updated tols variable |
| 26 | + """ |
| 27 | + _tols = (1. - s) * tols + (s * t) |
| 28 | + return _tols |
| 29 | + |
| 30 | +@jit |
| 31 | +def _dfv_internal(j, v, rfr, tau_m, refract_T): ## raw voltage dynamics |
| 32 | + mask = (rfr >= refract_T).astype(jnp.float32) # get refractory mask |
| 33 | + ## update voltage / membrane potential |
| 34 | + dv_dt = (j * mask) ## integration only involves electrical current |
| 35 | + dv_dt = dv_dt * (1./tau_m) |
| 36 | + return dv_dt |
| 37 | + |
| 38 | +def _dfv(t, v, params): ## voltage dynamics wrapper |
| 39 | + j, rfr, tau_m, refract_T = params |
| 40 | + dv_dt = _dfv_internal(j, v, rfr, tau_m, refract_T) |
| 41 | + return dv_dt |
| 42 | + |
| 43 | +def _run_cell(dt, j, v, v_thr, rfr, tau_m, v_rest, v_reset, refract_T, integType=0): |
| 44 | + ### Runs integrator (or integrate-and-fire; IF) neuronal dynamics |
| 45 | + ## update voltage / membrane potential |
| 46 | + v_params = (j, rfr, tau_m, refract_T) |
| 47 | + if integType == 1: |
| 48 | + _, _v = step_rk2(0., v, _dfv, dt, v_params) |
| 49 | + else: |
| 50 | + _, _v = step_euler(0., v, _dfv, dt, v_params) |
| 51 | + ## obtain action potentials/spikes |
| 52 | + s = (_v > v_thr).astype(jnp.float32) |
| 53 | + ## update refractory variables |
| 54 | + _rfr = (rfr + dt) * (1. - s) |
| 55 | + ## perform hyper-polarization of neuronal cells |
| 56 | + _v = _v * (1. - s) + s * v_reset |
| 57 | + return _v, s, _rfr |
| 58 | + |
| 59 | +class IFCell(JaxComponent): ## integrate-and-fire cell |
| 60 | + """ |
| 61 | + A spiking cell based on integrate-and-fire (IF) neuronal dynamics. |
| 62 | +
|
| 63 | + The specific differential equation that characterizes this cell |
| 64 | + is (for adjusting v, given current j, over time) is: |
| 65 | +
|
| 66 | + | tau_m * dv/dt = (v_rest - v) + j * R |
| 67 | + | where R is the membrane resistance and v_rest is the resting potential |
| 68 | + | also, if a spike occurs, v is set to v_reset |
| 69 | +
|
| 70 | + | --- Cell Input Compartments: --- |
| 71 | + | j - electrical current input (takes in external signals) |
| 72 | + | --- Cell State Compartments: --- |
| 73 | + | v - membrane potential/voltage state |
| 74 | + | rfr - (relative) refractory variable state |
| 75 | + | key - JAX PRNG key |
| 76 | + | --- Cell Output Compartments: --- |
| 77 | + | s - emitted binary spikes/action potentials |
| 78 | + | s_raw - raw spike signals before post-processing (only if one_spike = True, else s_raw = s) |
| 79 | + | tols - time-of-last-spike |
| 80 | +
|
| 81 | + Args: |
| 82 | + name: the string name of this cell |
| 83 | +
|
| 84 | + n_units: number of cellular entities (neural population size) |
| 85 | +
|
| 86 | + tau_m: membrane time constant |
| 87 | +
|
| 88 | + resist_m: membrane resistance value (default: 1) |
| 89 | +
|
| 90 | + thr: base value for adaptive thresholds that govern short-term |
| 91 | + plasticity (in milliVolts, or mV; default: -52. mV) |
| 92 | +
|
| 93 | + v_rest: membrane resting potential (in mV; default: -65 mV) |
| 94 | +
|
| 95 | + v_reset: membrane reset potential (in mV) -- upon occurrence of a spike, |
| 96 | + a neuronal cell's membrane potential will be set to this value; |
| 97 | + (default: -60 mV) |
| 98 | +
|
| 99 | + refract_time: relative refractory period time (ms; default: 0 ms) |
| 100 | +
|
| 101 | + integration_type: type of integration to use for this cell's dynamics; |
| 102 | + current supported forms include "euler" (Euler/RK-1 integration) |
| 103 | + and "midpoint" or "rk2" (midpoint method/RK-2 integration) (Default: "euler") |
| 104 | +
|
| 105 | + :Note: setting the integration type to the midpoint method will |
| 106 | + increase the accuray of the estimate of the cell's evolution |
| 107 | + at an increase in computational cost (and simulation time) |
| 108 | +
|
| 109 | + surrgoate_type: type of surrogate function to use for approximating a |
| 110 | + partial derivative of this cell's spikes w.r.t. its voltage/current |
| 111 | + (default: "straight_through") |
| 112 | +
|
| 113 | + :Note: surrogate options available include: "straight_through" |
| 114 | + (straight-through estimator), "triangular" (triangular estimator), |
| 115 | + and "arctan" (arc-tangent estimator) |
| 116 | +
|
| 117 | + lower_clamp_voltage: if True, this will ensure voltage never is below |
| 118 | + the value of `v_rest` (default: True) |
| 119 | + """ |
| 120 | + |
| 121 | + @deprecate_args(thr_jitter=None) |
| 122 | + def __init__(self, name, n_units, tau_m, resist_m=1., thr=-52., v_rest=-65., |
| 123 | + v_reset=-60., refract_time=0., integration_type="euler", |
| 124 | + surrgoate_type="straight_through", lower_clamp_voltage=True, |
| 125 | + **kwargs): |
| 126 | + super().__init__(name, **kwargs) |
| 127 | + |
| 128 | + ## Integration properties |
| 129 | + self.integrationType = integration_type |
| 130 | + self.intgFlag = get_integrator_code(self.integrationType) |
| 131 | + |
| 132 | + ## membrane parameter setup (affects ODE integration) |
| 133 | + self.tau_m = tau_m ## membrane time constant |
| 134 | + self.resist_m = resist_m ## resistance value |
| 135 | + |
| 136 | + self.v_rest = v_rest #-65. # mV |
| 137 | + self.v_reset = v_reset # -60. # -65. # mV (milli-volts) |
| 138 | + ## basic asserts to prevent neuronal dynamics breaking... |
| 139 | + assert self.resist_m > 0. |
| 140 | + self.refract_T = refract_time #5. # 2. ## refractory period # ms |
| 141 | + self.thr = thr ## (fixed) base value for threshold #-52 # -72. # mV |
| 142 | + self.lower_clamp_voltage = lower_clamp_voltage |
| 143 | + |
| 144 | + ## Layer Size Setup |
| 145 | + self.batch_size = 1 |
| 146 | + self.n_units = n_units |
| 147 | + |
| 148 | + ## set up surrogate function for spike emission |
| 149 | + if surrgoate_type == "arctan": |
| 150 | + self.spike_fx, self.d_spike_fx = arctan_estimator() |
| 151 | + elif surrgoate_type == "triangular": |
| 152 | + self.spike_fx, self.d_spike_fx = triangular_estimator() |
| 153 | + else: ## default: straight_through |
| 154 | + self.spike_fx, self.d_spike_fx = straight_through_estimator() |
| 155 | + |
| 156 | + |
| 157 | + ## Compartment setup |
| 158 | + restVals = jnp.zeros((self.batch_size, self.n_units)) |
| 159 | + self.j = Compartment(restVals, display_name="Current", units="mA") |
| 160 | + self.v = Compartment(restVals + self.v_rest, |
| 161 | + display_name="Voltage", units="mV") |
| 162 | + self.s = Compartment(restVals, display_name="Spikes") |
| 163 | + self.rfr = Compartment(restVals + self.refract_T, |
| 164 | + display_name="Refractory Time Period", units="ms") |
| 165 | + self.tols = Compartment(restVals, display_name="Time-of-Last-Spike", |
| 166 | + units="ms") ## time-of-last-spike |
| 167 | + self.surrogate = Compartment(restVals + 1., display_name="Surrogate State Value") |
| 168 | + |
| 169 | + @staticmethod |
| 170 | + def _advance_state(t, dt, tau_m, resist_m, v_rest, v_reset, refract_T, |
| 171 | + thr, lower_clamp_voltage, intgFlag, d_spike_fx, key, |
| 172 | + j, v, rfr, tols): |
| 173 | + ## run one integration step for neuronal dynamics |
| 174 | + j = j * resist_m |
| 175 | + v, s, rfr = _run_cell(dt, j, v, thr, rfr, tau_m, v_rest, v_reset, |
| 176 | + refract_T, intgFlag) |
| 177 | + surrogate = d_spike_fx(v, thr) |
| 178 | + ## update tols |
| 179 | + tols = _update_times(t, s, tols) |
| 180 | + if lower_clamp_voltage: ## ensure voltage never < v_rest |
| 181 | + v = jnp.maximum(v, v_rest) |
| 182 | + return v, s, rfr, tols, key, surrogate |
| 183 | + |
| 184 | + @resolver(_advance_state) |
| 185 | + def advance_state(self, v, s, rfr, tols, key, surrogate): |
| 186 | + self.v.set(v) |
| 187 | + self.s.set(s) |
| 188 | + self.rfr.set(rfr) |
| 189 | + self.tols.set(tols) |
| 190 | + self.key.set(key) |
| 191 | + self.surrogate.set(surrogate) |
| 192 | + |
| 193 | + @staticmethod |
| 194 | + def _reset(batch_size, n_units, v_rest, refract_T): |
| 195 | + restVals = jnp.zeros((batch_size, n_units)) |
| 196 | + j = restVals #+ 0 |
| 197 | + v = restVals + v_rest |
| 198 | + s = restVals #+ 0 |
| 199 | + rfr = restVals + refract_T |
| 200 | + tols = restVals #+ 0 |
| 201 | + surrogate = restVals + 1. |
| 202 | + return j, v, s, rfr, tols, surrogate |
| 203 | + |
| 204 | + @resolver(_reset) |
| 205 | + def reset(self, j, v, s, rfr, tols, surrogate): |
| 206 | + self.j.set(j) |
| 207 | + self.v.set(v) |
| 208 | + self.s.set(s) |
| 209 | + self.rfr.set(rfr) |
| 210 | + self.tols.set(tols) |
| 211 | + self.surrogate.set(surrogate) |
| 212 | + |
| 213 | + def save(self, directory, **kwargs): |
| 214 | + ## do a protected save of constants, depending on whether they are floats or arrays |
| 215 | + tau_m = (self.tau_m if isinstance(self.tau_m, float) |
| 216 | + else jnp.ones([[self.tau_m]])) |
| 217 | + thr = (self.thr if isinstance(self.thr, float) |
| 218 | + else jnp.ones([[self.thr]])) |
| 219 | + v_rest = (self.v_rest if isinstance(self.v_rest, float) |
| 220 | + else jnp.ones([[self.v_rest]])) |
| 221 | + v_reset = (self.v_reset if isinstance(self.v_reset, float) |
| 222 | + else jnp.ones([[self.v_reset]])) |
| 223 | + v_decay = (self.v_decay if isinstance(self.v_decay, float) |
| 224 | + else jnp.ones([[self.v_decay]])) |
| 225 | + resist_m = (self.resist_m if isinstance(self.resist_m, float) |
| 226 | + else jnp.ones([[self.resist_m]])) |
| 227 | + tau_theta = (self.tau_theta if isinstance(self.tau_theta, float) |
| 228 | + else jnp.ones([[self.tau_theta]])) |
| 229 | + theta_plus = (self.theta_plus if isinstance(self.theta_plus, float) |
| 230 | + else jnp.ones([[self.theta_plus]])) |
| 231 | + |
| 232 | + file_name = directory + "/" + self.name + ".npz" |
| 233 | + jnp.savez(file_name, |
| 234 | + tau_m=tau_m, thr=thr, v_rest=v_rest, |
| 235 | + v_reset=v_reset, v_decay=v_decay, |
| 236 | + resist_m=resist_m, tau_theta=tau_theta, |
| 237 | + theta_plus=theta_plus, |
| 238 | + key=self.key.value) |
| 239 | + |
| 240 | + def load(self, directory, seeded=False, **kwargs): |
| 241 | + file_name = directory + "/" + self.name + ".npz" |
| 242 | + data = jnp.load(file_name) |
| 243 | + ## constants loaded in |
| 244 | + self.tau_m = data['tau_m'] |
| 245 | + self.thr = data['thr'] |
| 246 | + self.v_rest = data['v_rest'] |
| 247 | + self.v_reset = data['v_reset'] |
| 248 | + self.v_decay = data['v_decay'] |
| 249 | + self.resist_m = data['resist_m'] |
| 250 | + self.tau_theta = data['tau_theta'] |
| 251 | + self.theta_plus = data['theta_plus'] |
| 252 | + |
| 253 | + if seeded: |
| 254 | + self.key.set(data['key']) |
| 255 | + |
| 256 | + @classmethod |
| 257 | + def help(cls): ## component help function |
| 258 | + properties = { |
| 259 | + "cell_type": "IFCell - evolves neurons according to integrate-" |
| 260 | + "and-fire spiking dynamics." |
| 261 | + } |
| 262 | + compartment_props = { |
| 263 | + "inputs": |
| 264 | + {"j": "External input electrical current"}, |
| 265 | + "states": |
| 266 | + {"v": "Membrane potential/voltage at time t", |
| 267 | + "rfr": "Current state of (relative) refractory variable", |
| 268 | + "thr": "Current state of voltage threshold at time t", |
| 269 | + "key": "JAX PRNG key"}, |
| 270 | + "outputs": |
| 271 | + {"s": "Emitted spikes/pulses at time t", |
| 272 | + "tols": "Time-of-last-spike"}, |
| 273 | + } |
| 274 | + hyperparams = { |
| 275 | + "n_units": "Number of neuronal cells to model in this layer", |
| 276 | + "tau_m": "Cell membrane time constant", |
| 277 | + "resist_m": "Membrane resistance value", |
| 278 | + "thr": "Base voltage threshold value", |
| 279 | + "v_rest": "Resting membrane potential value", |
| 280 | + "v_reset": "Reset membrane potential value", |
| 281 | + "refract_time": "Length of relative refractory period (ms)", |
| 282 | + "integration_type": "Type of numerical integration to use for the cell dynamics", |
| 283 | + "surrgoate_type": "Type of surrogate function to use approximate " |
| 284 | + "derivative of spike w.r.t. voltage/current", |
| 285 | + "lower_bound_clamp": "Should voltage be lower bounded to be never be below `v_rest`" |
| 286 | + } |
| 287 | + info = {cls.__name__: properties, |
| 288 | + "compartments": compartment_props, |
| 289 | + "dynamics": "tau_m * dv/dt = (v_rest - v) + j * resist_m", |
| 290 | + "hyperparameters": hyperparams} |
| 291 | + return info |
| 292 | + |
| 293 | + def __repr__(self): |
| 294 | + comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))] |
| 295 | + maxlen = max(len(c) for c in comps) + 5 |
| 296 | + lines = f"[{self.__class__.__name__}] PATH: {self.name}\n" |
| 297 | + for c in comps: |
| 298 | + stats = tensorstats(getattr(self, c).value) |
| 299 | + if stats is not None: |
| 300 | + line = [f"{k}: {v}" for k, v in stats.items()] |
| 301 | + line = ", ".join(line) |
| 302 | + else: |
| 303 | + line = "None" |
| 304 | + lines += f" {f'({c})'.ljust(maxlen)}{line}\n" |
| 305 | + return lines |
| 306 | + |
0 commit comments