@@ -62,16 +62,15 @@ class GaussianErrorCell(JaxComponent): ## Rate-coded/real-valued error unit/cell
62
62
63
63
n_units: number of cellular entities (neural population size)
64
64
65
- tau_m: (Unused -- currently cell is a fixed-point model)
66
-
67
- leakRate: (Unused -- currently cell is a fixed-point model)
65
+ refract_time: relative refractory period time (ms; Default: 0 ms)
68
66
"""
69
- def __init__ (self , name , n_units , batch_size = 1 , ** kwargs ):
67
+ def __init__ (self , name , n_units , refract_time = 0. , batch_size = 1 , ** kwargs ):
70
68
super ().__init__ (name , ** kwargs )
71
69
72
70
## Layer Size Setup
73
71
self .n_units = n_units
74
72
self .batch_size = batch_size
73
+ self .refract_T = refract_time # ms ## refractory period
75
74
76
75
## Convolution shape setup
77
76
self .width = self .height = n_units
@@ -84,39 +83,47 @@ def __init__(self, name, n_units, batch_size=1, **kwargs):
84
83
self .target = Compartment (restVals ) # target. input wire
85
84
self .dtarget = Compartment (restVals ) # derivative target
86
85
self .modulator = Compartment (restVals + 1.0 ) # to be set/consumed
86
+ self .rfr = Compartment (restVals + self .refract_T ) ## refractory variable(s)
87
87
88
88
@staticmethod
89
- def _advance_state (dt , mu , dmu , target , dtarget , modulator ):
89
+ def _advance_state (dt , refract_T , mu , dmu , target , dtarget , modulator , rfr ):
90
+ mask = (rfr >= refract_T ) * 1.
90
91
## compute Gaussian error cell output
91
- dmu , dtarget , L = _run_cell (dt , target , mu )
92
- dmu = dmu * modulator
93
- dtarget = dtarget * modulator
94
- return dmu , dtarget , L
92
+ dmu , dtarget , L = _run_cell (dt , target * mask , mu * mask )
93
+ dmu = dmu * modulator * mask
94
+ dtarget = dtarget * modulator * mask
95
+ if refract_T > 0. : ## if non-zero refractory times used, then...
96
+ rfr = (rfr + dt ) * (1. - target ) + target * dt # set refract to dt
97
+ return dmu , dtarget , L , rfr
95
98
96
99
@resolver (_advance_state )
97
- def advance_state (self , dmu , dtarget , L ):
100
+ def advance_state (self , dmu , dtarget , L , rfr ):
98
101
self .dmu .set (dmu )
99
102
self .dtarget .set (dtarget )
100
103
self .L .set (L )
104
+ self .rfr .set (rfr )
101
105
102
106
@staticmethod
103
- def _reset (batch_size , n_units ):
104
- dmu = jnp .zeros ((batch_size , n_units ))
105
- dtarget = jnp .zeros ((batch_size , n_units ))
106
- target = jnp .zeros ((batch_size , n_units )) #None
107
- mu = jnp .zeros ((batch_size , n_units )) #None
107
+ def _reset (refract_T , batch_size , n_units ):
108
+ restVals = jnp .zeros ((batch_size , n_units ))
109
+ dmu = restVals
110
+ dtarget = restVals
111
+ target = restVals
112
+ mu = restVals
108
113
modulator = mu + 1.
109
114
L = 0.
110
- return dmu , dtarget , target , mu , modulator , L
115
+ rfr = restVals + refract_T
116
+ return dmu , dtarget , target , mu , modulator , L , rfr
111
117
112
118
@resolver (_reset )
113
- def reset (self , dmu , dtarget , target , mu , modulator , L ):
119
+ def reset (self , dmu , dtarget , target , mu , modulator , L , rfr ):
114
120
self .dmu .set (dmu )
115
121
self .dtarget .set (dtarget )
116
122
self .target .set (target )
117
123
self .mu .set (mu )
118
124
self .modulator .set (modulator )
119
125
self .L .set (L )
126
+ self .rfr .set (rfr )
120
127
121
128
@classmethod
122
129
def help (cls ): ## component help function
0 commit comments