From 16588fab74e1a1f07e2d66fe2275c5391c107eb5 Mon Sep 17 00:00:00 2001 From: ago109 Date: Thu, 27 Jun 2024 21:24:07 -0400 Subject: [PATCH] added spike-reset/snap-back to fn-cell --- .../neurons/spiking/fitzhughNagumoCell.py | 22 ++++++++++++++----- 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/ngclearn/components/neurons/spiking/fitzhughNagumoCell.py b/ngclearn/components/neurons/spiking/fitzhughNagumoCell.py index 32a519840..e3feb206d 100755 --- a/ngclearn/components/neurons/spiking/fitzhughNagumoCell.py +++ b/ngclearn/components/neurons/spiking/fitzhughNagumoCell.py @@ -113,13 +113,17 @@ class FitzhughNagumoCell(JaxComponent): gamma: power-term divisor (Default: 3.) - v_thr: voltage/membrane threshold (to obtain action potentials in terms - of binary spikes) - v0: initial condition / reset for voltage w0: initial condition / reset for recovery + v_thr: voltage/membrane threshold (to obtain action potentials in terms + of binary spikes) + + spike_reset: if True, once voltage crosses threshold, then dynamics + of voltage and recovery are reset/snapped to initial conditions + (default: False) + integration_type: type of integration to use for this cell's dynamics; current supported forms include "euler" (Euler/RK-1 integration) and "midpoint" or "rk2" (midpoint method/RK-2 integration) (Default: "euler") @@ -131,7 +135,7 @@ class FitzhughNagumoCell(JaxComponent): # Define Functions def __init__(self, name, n_units, tau_m=1., resist_m=1., tau_w=12.5, alpha=0.7, - beta=0.8, gamma=3., v_thr=1.07, v0=0., w0=0., + beta=0.8, gamma=3., v0=0., w0=0., v_thr=1.07, spike_reset=False, integration_type="euler", **kwargs): super().__init__(name, **kwargs) @@ -150,6 +154,7 @@ def __init__(self, name, n_units, tau_m=1., resist_m=1., tau_w=12.5, alpha=0.7, self.v0 = v0 ## initial membrane potential/voltage condition self.w0 = w0 ## initial w-parameter condition self.v_thr = v_thr + self.spike_reset = spike_reset ## Layer Size Setup self.batch_size = 1 @@ -164,10 +169,13 @@ def __init__(self, name, n_units, tau_m=1., resist_m=1., tau_w=12.5, alpha=0.7, self.tols = Compartment(restVals) ## time-of-last-spike @staticmethod - def _advance_state(t, dt, tau_m, R_m, tau_w, v_thr, alpha, beta, gamma, - intgFlag, j, v, w, tols): + def _advance_state(t, dt, tau_m, R_m, tau_w, v_thr, spike_reset, v0, w0, alpha, + beta, gamma, intgFlag, j, v, w, tols): v, w, s = _run_cell(dt, j * R_m, v, w, v_thr, tau_m, tau_w, alpha, beta, gamma, intgFlag) + if spike_reset: ## if spike-reset used, variables snapped back to initial conditions + v = v * (1. - s) + s * v0 + w = w * (1. - s) + s * w0 tols = _update_times(t, s, tols) return j, v, w, s, tols @@ -220,6 +228,8 @@ def help(cls): ## component help function "resist_m": "Membrane resistance value", "tau_w": "Recovery variable time constant", "v_thr": "Base voltage threshold value", + "spike_reset": "Should voltage/recover be snapped to initial " + "condition(s) if spike emitted?", "alpha": "Dimensionless recovery variable shift factor `a", "beta": "Dimensionless recovery variable scale factor `b`", "gamma": "Power-term divisor constant",