@@ -149,13 +149,16 @@ class LatencyCell(JaxComponent):
149
149
:Note: if this set to True, you will need to choose a useful value
150
150
for the "num_steps" argument (>1), depending on how many steps simulated
151
151
152
+ clip_spikes: should values under threshold be removed/suppressed?
153
+ (default: False)
154
+
152
155
num_steps: number of discrete time steps to consider for normalized latency
153
156
code (only useful if "normalize" is set to True) (Default: 1)
154
157
"""
155
158
156
159
# Define Functions
157
160
def __init__ (self , name , n_units , tau = 1. , threshold = 0.01 , first_spike_time = 0. ,
158
- linearize = False , normalize = False , num_steps = 1. ,
161
+ linearize = False , normalize = False , clip_spikes = False , num_steps = 1. ,
159
162
batch_size = 1 , ** kwargs ):
160
163
super ().__init__ (name , ** kwargs )
161
164
@@ -164,6 +167,7 @@ def __init__(self, name, n_units, tau=1., threshold=0.01, first_spike_time=0.,
164
167
self .tau = tau
165
168
self .threshold = threshold
166
169
self .linearize = linearize
170
+ self .clip_spikes = clip_spikes
167
171
## normalize latency code s.t. final spike(s) occur w/in num_steps
168
172
self .normalize = normalize
169
173
self .num_steps = num_steps
@@ -176,17 +180,22 @@ def __init__(self, name, n_units, tau=1., threshold=0.01, first_spike_time=0.,
176
180
restVals = jnp .zeros ((self .batch_size , self .n_units ))
177
181
self .inputs = Compartment (restVals , display_name = "Input Stimulus" ) # input compartment
178
182
self .outputs = Compartment (restVals , display_name = "Spikes" ) # output compartment
179
- self .mask = Compartment (restVals , display_name = "Mask Variable" ) # output compartment
183
+ self .mask = Compartment (restVals , display_name = "Spike Time Mask" )
184
+ self .clip_mask = Compartment (restVals , display_name = "Clip Mask" )
180
185
self .tols = Compartment (restVals , display_name = "Time-of-Last-Spike" , units = "ms" ) # time of last spike
181
186
self .targ_sp_times = Compartment (restVals , display_name = "Target Spike Time" , units = "ms" )
182
187
#self.reset()
183
188
184
189
@staticmethod
185
190
def _calc_spike_times (linearize , tau , threshold , first_spike_time , num_steps ,
186
- normalize , inputs ):
191
+ normalize , clip_spikes , inputs ):
187
192
## would call this function before processing a spike train (at start)
188
193
data = inputs
189
- if linearize == True : ## linearize spike time calculation
194
+ if clip_spikes :
195
+ clip_mask = (data < threshold ) * 1. ## find values under threshold
196
+ else :
197
+ clip_mask = data * 0. ## all values allowed to fire spikes
198
+ if linearize : ## linearize spike time calculation
190
199
stimes = _calc_spike_times_linear (data , tau , threshold ,
191
200
first_spike_time ,
192
201
num_steps , normalize )
@@ -197,18 +206,20 @@ def _calc_spike_times(linearize, tau, threshold, first_spike_time, num_steps,
197
206
num_steps = num_steps ,
198
207
normalize = normalize )
199
208
targ_sp_times = stimes #* calcEvent + targ_sp_times * (1. - calcEvent)
200
- return targ_sp_times
209
+ return targ_sp_times , clip_mask
201
210
202
211
@resolver (_calc_spike_times )
203
- def calc_spike_times (self , targ_sp_times ):
212
+ def calc_spike_times (self , targ_sp_times , clip_mask ):
204
213
self .targ_sp_times .set (targ_sp_times )
214
+ self .clip_mask .set (clip_mask )
205
215
206
216
@staticmethod
207
- def _advance_state (t , dt , key , inputs , mask , targ_sp_times , tols ):
217
+ def _advance_state (t , dt , key , inputs , mask , clip_mask , targ_sp_times , tols ):
208
218
key , * subkeys = random .split (key , 2 )
209
219
data = inputs ## get sensory pattern data / features
210
220
spikes , spk_mask = _extract_spike (targ_sp_times , t , mask ) ## get spikes at t
211
221
tols = _update_times (t , spikes , tols )
222
+ spikes = spikes * (1. - clip_mask )
212
223
return spikes , tols , spk_mask , targ_sp_times , key
213
224
214
225
@resolver (_advance_state )
@@ -222,14 +233,15 @@ def advance_state(self, outputs, tols, mask, targ_sp_times, key):
222
233
@staticmethod
223
234
def _reset (batch_size , n_units ):
224
235
restVals = jnp .zeros ((batch_size , n_units ))
225
- return (restVals , restVals , restVals , restVals , restVals )
236
+ return (restVals , restVals , restVals , restVals , restVals , restVals )
226
237
227
238
@resolver (_reset )
228
- def reset (self , inputs , outputs , tols , mask , targ_sp_times ):
239
+ def reset (self , inputs , outputs , tols , mask , clip_mask , targ_sp_times ):
229
240
self .inputs .set (inputs )
230
241
self .outputs .set (outputs )
231
242
self .tols .set (tols )
232
243
self .mask .set (mask )
244
+ self .clip_mask .set (clip_mask )
233
245
self .targ_sp_times .set (targ_sp_times )
234
246
235
247
def save (self , directory , ** kwargs ):
0 commit comments