Skip to content

Commit 55eea4c

Browse files
committed
cleanly integrated masking into gauss/laplace err-cells
1 parent 8ece2be commit 55eea4c

File tree

2 files changed

+24
-14
lines changed

2 files changed

+24
-14
lines changed

ngclearn/components/neurons/graded/gaussianErrorCell.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ class GaussianErrorCell(JaxComponent): ## Rate-coded/real-valued error unit/cell
5252
| mu - predicted value (takes in external signals)
5353
| target - desired/goal value (takes in external signals)
5454
| modulator - modulation signal (takes in optional external signals)
55+
| mask - binary/gating mask to apply to error neuron calculations
5556
| --- Cell Output Compartments: ---
5657
| L - local loss function embodied by this cell
5758
| dmu - derivative of L w.r.t. mu
@@ -134,7 +135,8 @@ def help(cls): ## component help function
134135
"inputs":
135136
{"mu": "External input prediction value(s)",
136137
"target": "External input target signal value(s)",
137-
"modulator": "External input modulatory/scaling signal(s)"},
138+
"modulator": "External input modulatory/scaling signal(s)",
139+
"mask": "External binary/gating mask to apply to signals"},
138140
"outputs":
139141
{"L": "Local loss value computed/embodied by this error-cell",
140142
"dmu": "first derivative of loss w.r.t. prediction value(s)",

ngclearn/components/neurons/graded/laplacianErrorCell.py

+21-13
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ class LaplacianErrorCell(JaxComponent): ## Rate-coded/real-valued error unit/cel
5252
| mu - predicted value (takes in external signals)
5353
| target - desired/goal value (takes in external signals)
5454
| modulator - modulation signal (takes in optional external signals)
55+
| mask - binary/gating mask to apply to error neuron calculations
5556
| --- Cell Output Compartments: ---
5657
| L - local loss function embodied by this cell
5758
| dmu - derivative of L w.r.t. mu
@@ -86,39 +87,45 @@ def __init__(self, name, n_units, batch_size=1, **kwargs):
8687
self.target = Compartment(restVals) # target. input wire
8788
self.dtarget = Compartment(restVals) # derivative target
8889
self.modulator = Compartment(restVals + 1.0) # to be set/consumed
90+
self.mask = Compartment(restVals + 1.0)
8991

9092
@staticmethod
91-
def _advance_state(dt, mu, target, modulator):
93+
def _advance_state(dt, mu, target, modulator, mask):
9294
## compute Laplacian error cell output
93-
dmu, dtarget, L = _run_cell(dt, target, mu)
94-
dmu = dmu * modulator
95-
dtarget = dtarget * modulator
96-
return dmu, dtarget, L
95+
dmu, dtarget, L = _run_cell(dt, target * mask, mu * mask)
96+
dmu = dmu * modulator * mask
97+
dtarget = dtarget * modulator * mask
98+
mask = mask * 0. + 1. ## "eat" the mask as it should only apply at time t
99+
return dmu, dtarget, L, mask
97100

98101
@resolver(_advance_state)
99-
def advance_state(self, dmu, dtarget, L):
102+
def advance_state(self, dmu, dtarget, L, mask):
100103
self.dmu.set(dmu)
101104
self.dtarget.set(dtarget)
102105
self.L.set(L)
106+
self.mask.set(mask)
103107

104108
@staticmethod
105109
def _reset(batch_size, n_units):
106-
dmu = jnp.zeros((batch_size, n_units))
107-
dtarget = jnp.zeros((batch_size, n_units))
108-
target = jnp.zeros((batch_size, n_units)) #None
109-
mu = jnp.zeros((batch_size, n_units)) #None
110+
restVals = jnp.zeros((batch_size, n_units))
111+
dmu = restVals
112+
dtarget = restVals
113+
target = restVals
114+
mu = restVals
110115
modulator = mu + 1.
111116
L = 0.
112-
return dmu, dtarget, target, mu, modulator, L
117+
mask = jnp.ones((batch_size, n_units))
118+
return dmu, dtarget, target, mu, modulator, L, mask
113119

114120
@resolver(_reset)
115-
def reset(self, dmu, dtarget, target, mu, modulator, L):
121+
def reset(self, dmu, dtarget, target, mu, modulator, L, mask):
116122
self.dmu.set(dmu)
117123
self.dtarget.set(dtarget)
118124
self.target.set(target)
119125
self.mu.set(mu)
120126
self.modulator.set(modulator)
121127
self.L.set(L)
128+
self.mask.set(mask)
122129

123130
@classmethod
124131
def help(cls): ## component help function
@@ -130,7 +137,8 @@ def help(cls): ## component help function
130137
"inputs":
131138
{"mu": "External input prediction value(s)",
132139
"target": "External input target signal value(s)",
133-
"modulator": "External input modulatory/scaling signal(s)"},
140+
"modulator": "External input modulatory/scaling signal(s)",
141+
"mask": "External binary/gating mask to apply to signals"},
134142
"outputs":
135143
{"L": "Local loss value computed/embodied by this error-cell",
136144
"dmu": "first derivative of loss w.r.t. prediction value(s)",

0 commit comments

Comments
 (0)