Skip to content

Commit

Permalink
added masking to gauss err-cell
Browse files Browse the repository at this point in the history
  • Loading branch information
ago109 committed Jul 1, 2024
1 parent 5defaa7 commit 71d9bf6
Showing 1 changed file with 15 additions and 16 deletions.
31 changes: 15 additions & 16 deletions ngclearn/components/neurons/graded/gaussianErrorCell.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,15 +62,16 @@ class GaussianErrorCell(JaxComponent): ## Rate-coded/real-valued error unit/cell
n_units: number of cellular entities (neural population size)
refract_time: relative refractory period time (ms; Default: 0 ms)
tau_m: (Unused -- currently cell is a fixed-point model)
leakRate: (Unused -- currently cell is a fixed-point model)
"""
def __init__(self, name, n_units, refract_time=0., batch_size=1, **kwargs):
def __init__(self, name, n_units, batch_size=1, **kwargs):
super().__init__(name, **kwargs)

## Layer Size Setup
self.n_units = n_units
self.batch_size = batch_size
self.refract_T = refract_time # ms ## refractory period

## Convolution shape setup
self.width = self.height = n_units
Expand All @@ -83,47 +84,45 @@ def __init__(self, name, n_units, refract_time=0., batch_size=1, **kwargs):
self.target = Compartment(restVals) # target. input wire
self.dtarget = Compartment(restVals) # derivative target
self.modulator = Compartment(restVals + 1.0) # to be set/consumed
self.rfr = Compartment(restVals + self.refract_T) ## refractory variable(s)
self.mask = Compartment(restVals + 1.0)

@staticmethod
def _advance_state(dt, refract_T, mu, dmu, target, dtarget, modulator, rfr):
mask = (rfr >= refract_T) * 1.
def _advance_state(dt, mu, dmu, target, dtarget, modulator, mask):
## compute Gaussian error cell output
dmu, dtarget, L = _run_cell(dt, target * mask, mu * mask)
dmu = dmu * modulator * mask
dtarget = dtarget * modulator * mask
if refract_T > 0.: ## if non-zero refractory times used, then...
rfr = (rfr + dt) * (1. - target) + target * dt # set refract to dt
return dmu, dtarget, L, rfr
mask = mask * 0. + 1. ## "eat" the mask as it should only apply at time t
return dmu, dtarget, L, mask

@resolver(_advance_state)
def advance_state(self, dmu, dtarget, L, rfr):
def advance_state(self, dmu, dtarget, L, mask):
self.dmu.set(dmu)
self.dtarget.set(dtarget)
self.L.set(L)
self.rfr.set(rfr)
self.mask.set(mask)

@staticmethod
def _reset(refract_T, batch_size, n_units):
def _reset(batch_size, n_units):
restVals = jnp.zeros((batch_size, n_units))
dmu = restVals
dtarget = restVals
target = restVals
mu = restVals
modulator = mu + 1.
L = 0.
rfr = restVals + refract_T
return dmu, dtarget, target, mu, modulator, L, rfr
mask = jnp.ones((batch_size, n_units))
return dmu, dtarget, target, mu, modulator, L, mask

@resolver(_reset)
def reset(self, dmu, dtarget, target, mu, modulator, L, rfr):
def reset(self, dmu, dtarget, target, mu, modulator, L, mask):
self.dmu.set(dmu)
self.dtarget.set(dtarget)
self.target.set(target)
self.mu.set(mu)
self.modulator.set(modulator)
self.L.set(L)
self.rfr.set(rfr)
self.mask.set(mask)

@classmethod
def help(cls): ## component help function
Expand Down

0 comments on commit 71d9bf6

Please sign in to comment.