Skip to content

Commit 5defaa7

Browse files
committed
added refractoriness to gauss err-cell
1 parent 27c8a89 commit 5defaa7

File tree

1 file changed

+24
-17
lines changed

1 file changed

+24
-17
lines changed

ngclearn/components/neurons/graded/gaussianErrorCell.py

+24-17
Original file line numberDiff line numberDiff line change
@@ -62,16 +62,15 @@ class GaussianErrorCell(JaxComponent): ## Rate-coded/real-valued error unit/cell
6262
6363
n_units: number of cellular entities (neural population size)
6464
65-
tau_m: (Unused -- currently cell is a fixed-point model)
66-
67-
leakRate: (Unused -- currently cell is a fixed-point model)
65+
refract_time: relative refractory period time (ms; Default: 0 ms)
6866
"""
69-
def __init__(self, name, n_units, batch_size=1, **kwargs):
67+
def __init__(self, name, n_units, refract_time=0., batch_size=1, **kwargs):
7068
super().__init__(name, **kwargs)
7169

7270
## Layer Size Setup
7371
self.n_units = n_units
7472
self.batch_size = batch_size
73+
self.refract_T = refract_time # ms ## refractory period
7574

7675
## Convolution shape setup
7776
self.width = self.height = n_units
@@ -84,39 +83,47 @@ def __init__(self, name, n_units, batch_size=1, **kwargs):
8483
self.target = Compartment(restVals) # target. input wire
8584
self.dtarget = Compartment(restVals) # derivative target
8685
self.modulator = Compartment(restVals + 1.0) # to be set/consumed
86+
self.rfr = Compartment(restVals + self.refract_T) ## refractory variable(s)
8787

8888
@staticmethod
89-
def _advance_state(dt, mu, dmu, target, dtarget, modulator):
89+
def _advance_state(dt, refract_T, mu, dmu, target, dtarget, modulator, rfr):
90+
mask = (rfr >= refract_T) * 1.
9091
## compute Gaussian error cell output
91-
dmu, dtarget, L = _run_cell(dt, target, mu)
92-
dmu = dmu * modulator
93-
dtarget = dtarget * modulator
94-
return dmu, dtarget, L
92+
dmu, dtarget, L = _run_cell(dt, target * mask, mu * mask)
93+
dmu = dmu * modulator * mask
94+
dtarget = dtarget * modulator * mask
95+
if refract_T > 0.: ## if non-zero refractory times used, then...
96+
rfr = (rfr + dt) * (1. - target) + target * dt # set refract to dt
97+
return dmu, dtarget, L, rfr
9598

9699
@resolver(_advance_state)
97-
def advance_state(self, dmu, dtarget, L):
100+
def advance_state(self, dmu, dtarget, L, rfr):
98101
self.dmu.set(dmu)
99102
self.dtarget.set(dtarget)
100103
self.L.set(L)
104+
self.rfr.set(rfr)
101105

102106
@staticmethod
103-
def _reset(batch_size, n_units):
104-
dmu = jnp.zeros((batch_size, n_units))
105-
dtarget = jnp.zeros((batch_size, n_units))
106-
target = jnp.zeros((batch_size, n_units)) #None
107-
mu = jnp.zeros((batch_size, n_units)) #None
107+
def _reset(refract_T, batch_size, n_units):
108+
restVals = jnp.zeros((batch_size, n_units))
109+
dmu = restVals
110+
dtarget = restVals
111+
target = restVals
112+
mu = restVals
108113
modulator = mu + 1.
109114
L = 0.
110-
return dmu, dtarget, target, mu, modulator, L
115+
rfr = restVals + refract_T
116+
return dmu, dtarget, target, mu, modulator, L, rfr
111117

112118
@resolver(_reset)
113-
def reset(self, dmu, dtarget, target, mu, modulator, L):
119+
def reset(self, dmu, dtarget, target, mu, modulator, L, rfr):
114120
self.dmu.set(dmu)
115121
self.dtarget.set(dtarget)
116122
self.target.set(target)
117123
self.mu.set(mu)
118124
self.modulator.set(modulator)
119125
self.L.set(L)
126+
self.rfr.set(rfr)
120127

121128
@classmethod
122129
def help(cls): ## component help function

0 commit comments

Comments
 (0)