From 71d9bf66cd9afad14783422af42f1bd4722b2754 Mon Sep 17 00:00:00 2001 From: ago109 Date: Mon, 1 Jul 2024 13:28:35 -0400 Subject: [PATCH] added masking to gauss err-cell --- .../neurons/graded/gaussianErrorCell.py | 31 +++++++++---------- 1 file changed, 15 insertions(+), 16 deletions(-) diff --git a/ngclearn/components/neurons/graded/gaussianErrorCell.py b/ngclearn/components/neurons/graded/gaussianErrorCell.py index 00529bfc2..b1c6784cd 100755 --- a/ngclearn/components/neurons/graded/gaussianErrorCell.py +++ b/ngclearn/components/neurons/graded/gaussianErrorCell.py @@ -62,15 +62,16 @@ class GaussianErrorCell(JaxComponent): ## Rate-coded/real-valued error unit/cell n_units: number of cellular entities (neural population size) - refract_time: relative refractory period time (ms; Default: 0 ms) + tau_m: (Unused -- currently cell is a fixed-point model) + + leakRate: (Unused -- currently cell is a fixed-point model) """ - def __init__(self, name, n_units, refract_time=0., batch_size=1, **kwargs): + def __init__(self, name, n_units, batch_size=1, **kwargs): super().__init__(name, **kwargs) ## Layer Size Setup self.n_units = n_units self.batch_size = batch_size - self.refract_T = refract_time # ms ## refractory period ## Convolution shape setup self.width = self.height = n_units @@ -83,28 +84,26 @@ def __init__(self, name, n_units, refract_time=0., batch_size=1, **kwargs): self.target = Compartment(restVals) # target. input wire self.dtarget = Compartment(restVals) # derivative target self.modulator = Compartment(restVals + 1.0) # to be set/consumed - self.rfr = Compartment(restVals + self.refract_T) ## refractory variable(s) + self.mask = Compartment(restVals + 1.0) @staticmethod - def _advance_state(dt, refract_T, mu, dmu, target, dtarget, modulator, rfr): - mask = (rfr >= refract_T) * 1. + def _advance_state(dt, mu, dmu, target, dtarget, modulator, mask): ## compute Gaussian error cell output dmu, dtarget, L = _run_cell(dt, target * mask, mu * mask) dmu = dmu * modulator * mask dtarget = dtarget * modulator * mask - if refract_T > 0.: ## if non-zero refractory times used, then... - rfr = (rfr + dt) * (1. - target) + target * dt # set refract to dt - return dmu, dtarget, L, rfr + mask = mask * 0. + 1. ## "eat" the mask as it should only apply at time t + return dmu, dtarget, L, mask @resolver(_advance_state) - def advance_state(self, dmu, dtarget, L, rfr): + def advance_state(self, dmu, dtarget, L, mask): self.dmu.set(dmu) self.dtarget.set(dtarget) self.L.set(L) - self.rfr.set(rfr) + self.mask.set(mask) @staticmethod - def _reset(refract_T, batch_size, n_units): + def _reset(batch_size, n_units): restVals = jnp.zeros((batch_size, n_units)) dmu = restVals dtarget = restVals @@ -112,18 +111,18 @@ def _reset(refract_T, batch_size, n_units): mu = restVals modulator = mu + 1. L = 0. - rfr = restVals + refract_T - return dmu, dtarget, target, mu, modulator, L, rfr + mask = jnp.ones((batch_size, n_units)) + return dmu, dtarget, target, mu, modulator, L, mask @resolver(_reset) - def reset(self, dmu, dtarget, target, mu, modulator, L, rfr): + def reset(self, dmu, dtarget, target, mu, modulator, L, mask): self.dmu.set(dmu) self.dtarget.set(dtarget) self.target.set(target) self.mu.set(mu) self.modulator.set(modulator) self.L.set(L) - self.rfr.set(rfr) + self.mask.set(mask) @classmethod def help(cls): ## component help function