@@ -52,6 +52,7 @@ class LaplacianErrorCell(JaxComponent): ## Rate-coded/real-valued error unit/cel
5252 | mu - predicted value (takes in external signals)
5353 | target - desired/goal value (takes in external signals)
5454 | modulator - modulation signal (takes in optional external signals)
55+ | mask - binary/gating mask to apply to error neuron calculations
5556 | --- Cell Output Compartments: ---
5657 | L - local loss function embodied by this cell
5758 | dmu - derivative of L w.r.t. mu
@@ -86,39 +87,45 @@ def __init__(self, name, n_units, batch_size=1, **kwargs):
8687 self .target = Compartment (restVals ) # target. input wire
8788 self .dtarget = Compartment (restVals ) # derivative target
8889 self .modulator = Compartment (restVals + 1.0 ) # to be set/consumed
90+ self .mask = Compartment (restVals + 1.0 )
8991
@staticmethod
def _advance_state(dt, mu, target, modulator, mask):
    """
    Advance this error cell one simulation step.

    Runs the Laplacian error kernel on the mask-gated prediction/target
    pair, scales the resulting derivatives by both the modulator and the
    mask, and then "consumes" the mask (restores it to all ones) so the
    gate only applies at the current time step.
    """
    ## gate the incoming signals before computing the cell's error terms
    d_mu, d_target, loss = _run_cell(dt, target * mask, mu * mask)
    ## modulate and gate the outgoing derivative signals
    d_mu = d_mu * modulator * mask
    d_target = d_target * modulator * mask
    ## consume the one-shot mask: next step sees an all-ones gate
    mask = mask * 0. + 1.
    return d_mu, d_target, loss, mask
97100
@resolver(_advance_state)
def advance_state(self, dmu, dtarget, L, mask):
    """Write the freshly advanced values back into their compartments."""
    updates = ((self.dmu, dmu), (self.dtarget, dtarget),
               (self.L, L), (self.mask, mask))
    for compartment, value in updates:
        compartment.set(value)
103107
104108 @staticmethod
105109 def _reset (batch_size , n_units ):
106- dmu = jnp .zeros ((batch_size , n_units ))
107- dtarget = jnp .zeros ((batch_size , n_units ))
108- target = jnp .zeros ((batch_size , n_units )) #None
109- mu = jnp .zeros ((batch_size , n_units )) #None
110+ restVals = jnp .zeros ((batch_size , n_units ))
111+ dmu = restVals
112+ dtarget = restVals
113+ target = restVals
114+ mu = restVals
110115 modulator = mu + 1.
111116 L = 0.
112- return dmu , dtarget , target , mu , modulator , L
117+ mask = jnp .ones ((batch_size , n_units ))
118+ return dmu , dtarget , target , mu , modulator , L , mask
113119
@resolver(_reset)
def reset(self, dmu, dtarget, target, mu, modulator, L, mask):
    """Restore every compartment of this cell to its rest-state value."""
    updates = ((self.dmu, dmu), (self.dtarget, dtarget),
               (self.target, target), (self.mu, mu),
               (self.modulator, modulator), (self.L, L),
               (self.mask, mask))
    for compartment, value in updates:
        compartment.set(value)
122129
123130 @classmethod
124131 def help (cls ): ## component help function
@@ -130,7 +137,8 @@ def help(cls): ## component help function
130137 "inputs" :
131138 {"mu" : "External input prediction value(s)" ,
132139 "target" : "External input target signal value(s)" ,
133- "modulator" : "External input modulatory/scaling signal(s)" },
140+ "modulator" : "External input modulatory/scaling signal(s)" ,
141+ "mask" : "External binary/gating mask to apply to signals" },
134142 "outputs" :
135143 {"L" : "Local loss value computed/embodied by this error-cell" ,
136144 "dmu" : "first derivative of loss w.r.t. prediction value(s)" ,
0 commit comments