Skip to content

Commit

Permalink
added surrogate compartment to LIFCell
Browse files Browse the repository at this point in the history
  • Loading branch information
ago109 committed Jul 1, 2024
1 parent e678750 commit a2d2a6d
Showing 1 changed file with 14 additions and 5 deletions.
19 changes: 14 additions & 5 deletions ngclearn/components/neurons/spiking/LIFCell.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from ngclearn.components.jaxComponent import JaxComponent
from ngclearn.utils.diffeq.ode_utils import get_integrator_code, \
step_euler, step_rk2
from ngclearn.utils.surrogate_fx import straight_through_estimator

@jit
def _update_times(t, s, tols):
Expand Down Expand Up @@ -235,6 +236,9 @@ def __init__(self, name, n_units, tau_m, resist_m=1., thr=-52., v_rest=-65.,
self.batch_size = 1
self.n_units = n_units

## set up surrogate function for spike emission
self.spike_fx, self.d_spike_fx = straight_through_estimator()

## Compartment setup
restVals = jnp.zeros((self.batch_size, self.n_units))
thr0 = 0.
Expand All @@ -249,17 +253,19 @@ def __init__(self, name, n_units, tau_m, resist_m=1., thr=-52., v_rest=-65.,
self.rfr = Compartment(restVals + self.refract_T)
self.thr_theta = Compartment(restVals + thr0)
self.tols = Compartment(restVals) ## time-of-last-spike
self.surrogate = Compartment(restVals + 1.) ## surrogate signal

@staticmethod
def _advance_state(t, dt, tau_m, R_m, v_rest, v_reset, v_decay, refract_T,
thr, tau_theta, theta_plus, one_spike, intgFlag,
thr, tau_theta, theta_plus, one_spike, intgFlag, d_spike_fx,
key, j, v, s, rfr, thr_theta, tols):
skey = None ## this is an empty dkey if single_spike mode turned off
if one_spike: ## old code ~> if self.one_spike is False:
key, skey = random.split(key, 2)
## run one integration step for neuronal dynamics
#j = _modify_current(j, dt, tau_m, R_m) ## re-scale current in prep for volt ODE
j = j * R_m
surrogate = d_spike_fx(j)
v, s, raw_spikes, rfr = _run_cell(dt, j, v, thr, thr_theta, rfr, skey,
tau_m, v_rest, v_reset, v_decay, refract_T,
intgFlag)
Expand All @@ -268,17 +274,18 @@ def _advance_state(t, dt, tau_m, R_m, v_rest, v_reset, v_decay, refract_T,
thr_theta = _update_theta(dt, thr_theta, raw_spikes, tau_theta, theta_plus)
## update tols
tols = _update_times(t, s, tols)
return v, s, raw_spikes, rfr, thr_theta, tols, key
return v, s, raw_spikes, rfr, thr_theta, tols, key, surrogate

@resolver(_advance_state)
def advance_state(self, v, s, s_raw, rfr, thr_theta, tols, key):
def advance_state(self, v, s, s_raw, rfr, thr_theta, tols, key, surrogate):
self.v.set(v)
self.s.set(s)
self.s_raw.set(s_raw)
self.rfr.set(rfr)
self.thr_theta.set(thr_theta)
self.tols.set(tols)
self.key.set(key)
self.surrogate.set(surrogate)

@staticmethod
def _reset(batch_size, n_units, v_rest, refract_T):
Expand All @@ -290,17 +297,19 @@ def _reset(batch_size, n_units, v_rest, refract_T):
rfr = restVals + refract_T
#thr_theta = restVals ## do not reset thr_theta
tols = restVals #+ 0
return j, v, s, s_raw, rfr, tols
surrogate = restVals + 1.
return j, v, s, s_raw, rfr, tols, surrogate

@resolver(_reset)
def reset(self, j, v, s, s_raw, rfr, tols):
def reset(self, j, v, s, s_raw, rfr, tols, surrogate):
self.j.set(j)
self.v.set(v)
self.s.set(s)
self.s_raw.set(s_raw)
self.rfr.set(rfr)
#self.thr_theta.set(thr_theta)
self.tols.set(tols)
self.surrogate.set(surrogate)

def save(self, directory, **kwargs):
file_name = directory + "/" + self.name + ".npz"
Expand Down

0 comments on commit a2d2a6d

Please sign in to comment.