@@ -52,6 +52,7 @@ class LaplacianErrorCell(JaxComponent): ## Rate-coded/real-valued error unit/cel
52
52
| mu - predicted value (takes in external signals)
53
53
| target - desired/goal value (takes in external signals)
54
54
| modulator - modulation signal (takes in optional external signals)
55
+ | mask - binary/gating mask to apply to error neuron calculations
55
56
| --- Cell Output Compartments: ---
56
57
| L - local loss function embodied by this cell
57
58
| dmu - derivative of L w.r.t. mu
@@ -86,39 +87,45 @@ def __init__(self, name, n_units, batch_size=1, **kwargs):
86
87
self .target = Compartment (restVals ) # target. input wire
87
88
self .dtarget = Compartment (restVals ) # derivative target
88
89
self .modulator = Compartment (restVals + 1.0 ) # to be set/consumed
90
+ self .mask = Compartment (restVals + 1.0 )
89
91
90
92
@staticmethod
91
- def _advance_state (dt , mu , target , modulator ):
93
+ def _advance_state (dt , mu , target , modulator , mask ):
92
94
## compute Laplacian error cell output
93
- dmu , dtarget , L = _run_cell (dt , target , mu )
94
- dmu = dmu * modulator
95
- dtarget = dtarget * modulator
96
- return dmu , dtarget , L
95
+ dmu , dtarget , L = _run_cell (dt , target * mask , mu * mask )
96
+ dmu = dmu * modulator * mask
97
+ dtarget = dtarget * modulator * mask
98
+ mask = mask * 0. + 1. ## "eat" the mask as it should only apply at time t
99
+ return dmu , dtarget , L , mask
97
100
98
101
@resolver (_advance_state )
99
- def advance_state (self , dmu , dtarget , L ):
102
+ def advance_state (self , dmu , dtarget , L , mask ):
100
103
self .dmu .set (dmu )
101
104
self .dtarget .set (dtarget )
102
105
self .L .set (L )
106
+ self .mask .set (mask )
103
107
104
108
@staticmethod
105
109
def _reset (batch_size , n_units ):
106
- dmu = jnp .zeros ((batch_size , n_units ))
107
- dtarget = jnp .zeros ((batch_size , n_units ))
108
- target = jnp .zeros ((batch_size , n_units )) #None
109
- mu = jnp .zeros ((batch_size , n_units )) #None
110
+ restVals = jnp .zeros ((batch_size , n_units ))
111
+ dmu = restVals
112
+ dtarget = restVals
113
+ target = restVals
114
+ mu = restVals
110
115
modulator = mu + 1.
111
116
L = 0.
112
- return dmu , dtarget , target , mu , modulator , L
117
+ mask = jnp .ones ((batch_size , n_units ))
118
+ return dmu , dtarget , target , mu , modulator , L , mask
113
119
114
120
@resolver (_reset )
115
- def reset (self , dmu , dtarget , target , mu , modulator , L ):
121
+ def reset (self , dmu , dtarget , target , mu , modulator , L , mask ):
116
122
self .dmu .set (dmu )
117
123
self .dtarget .set (dtarget )
118
124
self .target .set (target )
119
125
self .mu .set (mu )
120
126
self .modulator .set (modulator )
121
127
self .L .set (L )
128
+ self .mask .set (mask )
122
129
123
130
@classmethod
124
131
def help (cls ): ## component help function
@@ -130,7 +137,8 @@ def help(cls): ## component help function
130
137
"inputs" :
131
138
{"mu" : "External input prediction value(s)" ,
132
139
"target" : "External input target signal value(s)" ,
133
- "modulator" : "External input modulatory/scaling signal(s)" },
140
+ "modulator" : "External input modulatory/scaling signal(s)" ,
141
+ "mask" : "External binary/gating mask to apply to signals" },
134
142
"outputs" :
135
143
{"L" : "Local loss value computed/embodied by this error-cell" ,
136
144
"dmu" : "first derivative of loss w.r.t. prediction value(s)" ,
0 commit comments