Skip to content

Commit

Permalink
cleanly integrated masking into gauss/laplace err-cells
Browse files Browse the repository at this point in the history
  • Loading branch information
ago109 committed Jul 1, 2024
1 parent 8ece2be commit 55eea4c
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 14 deletions.
4 changes: 3 additions & 1 deletion ngclearn/components/neurons/graded/gaussianErrorCell.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ class GaussianErrorCell(JaxComponent): ## Rate-coded/real-valued error unit/cell
| mu - predicted value (takes in external signals)
| target - desired/goal value (takes in external signals)
| modulator - modulation signal (takes in optional external signals)
| mask - binary/gating mask to apply to error neuron calculations
| --- Cell Output Compartments: ---
| L - local loss function embodied by this cell
| dmu - derivative of L w.r.t. mu
Expand Down Expand Up @@ -134,7 +135,8 @@ def help(cls): ## component help function
"inputs":
{"mu": "External input prediction value(s)",
"target": "External input target signal value(s)",
"modulator": "External input modulatory/scaling signal(s)"},
"modulator": "External input modulatory/scaling signal(s)",
"mask": "External binary/gating mask to apply to signals"},
"outputs":
{"L": "Local loss value computed/embodied by this error-cell",
"dmu": "first derivative of loss w.r.t. prediction value(s)",
Expand Down
34 changes: 21 additions & 13 deletions ngclearn/components/neurons/graded/laplacianErrorCell.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ class LaplacianErrorCell(JaxComponent): ## Rate-coded/real-valued error unit/cel
| mu - predicted value (takes in external signals)
| target - desired/goal value (takes in external signals)
| modulator - modulation signal (takes in optional external signals)
| mask - binary/gating mask to apply to error neuron calculations
| --- Cell Output Compartments: ---
| L - local loss function embodied by this cell
| dmu - derivative of L w.r.t. mu
Expand Down Expand Up @@ -86,39 +87,45 @@ def __init__(self, name, n_units, batch_size=1, **kwargs):
self.target = Compartment(restVals) # target. input wire
self.dtarget = Compartment(restVals) # derivative target
self.modulator = Compartment(restVals + 1.0) # to be set/consumed
self.mask = Compartment(restVals + 1.0)

@staticmethod
def _advance_state(dt, mu, target, modulator, mask):
    """
    Advance this Laplacian error cell one step: compute the local loss and
    its derivatives from masked prediction/target signals.

    Args:
        dt: integration time constant / step size (forwarded to `_run_cell`)
        mu: prediction compartment value(s)
        target: target compartment value(s)
        modulator: multiplicative modulation applied to both derivatives
        mask: binary/gating mask applied to inputs and output derivatives

    Returns:
        (dmu, dtarget, L, mask) — derivatives w.r.t. mu and target, the
        local loss L, and the consumed mask (reset to all-ones).
    """
    ## compute Laplacian error cell output on masked signals
    dmu, dtarget, L = _run_cell(dt, target * mask, mu * mask)
    dmu = dmu * modulator * mask
    dtarget = dtarget * modulator * mask
    ## "eat" the mask — it should only apply at the current time step, so
    ## it is rewritten to ones for subsequent steps
    mask = mask * 0. + 1.
    return dmu, dtarget, L, mask

@resolver(_advance_state)
def advance_state(self, dmu, dtarget, L, mask):
    """Write the advanced values back into this cell's compartments."""
    self.dmu.set(dmu)
    self.dtarget.set(dtarget)
    self.L.set(L)
    self.mask.set(mask)  ## mask comes back "eaten" (all ones)

@staticmethod
def _reset(batch_size, n_units):
    """
    Produce this cell's resting-state compartment values.

    Args:
        batch_size: number of rows in each compartment array
        n_units: number of error units (columns) per compartment

    Returns:
        (dmu, dtarget, target, mu, modulator, L, mask) — zero arrays for
        the signal compartments, a unit modulator, scalar loss L = 0, and
        an all-ones mask (i.e. no gating by default).
    """
    restVals = jnp.zeros((batch_size, n_units))
    ## jnp arrays are immutable, so sharing one zeros array is safe
    dmu = restVals
    dtarget = restVals
    target = restVals
    mu = restVals
    modulator = mu + 1.  ## default modulation is the identity (×1)
    L = 0.
    mask = jnp.ones((batch_size, n_units))
    return dmu, dtarget, target, mu, modulator, L, mask

@resolver(_reset)
def reset(self, dmu, dtarget, target, mu, modulator, L, mask):
    """Restore every compartment of this cell to its resting state."""
    self.dmu.set(dmu)
    self.dtarget.set(dtarget)
    self.target.set(target)
    self.mu.set(mu)
    self.modulator.set(modulator)
    self.L.set(L)
    self.mask.set(mask)

@classmethod
def help(cls): ## component help function
Expand All @@ -130,7 +137,8 @@ def help(cls): ## component help function
"inputs":
{"mu": "External input prediction value(s)",
"target": "External input target signal value(s)",
"modulator": "External input modulatory/scaling signal(s)"},
"modulator": "External input modulatory/scaling signal(s)",
"mask": "External binary/gating mask to apply to signals"},
"outputs":
{"L": "Local loss value computed/embodied by this error-cell",
"dmu": "first derivative of loss w.r.t. prediction value(s)",
Expand Down

0 comments on commit 55eea4c

Please sign in to comment.