Skip to content

Commit c894b8a

Browse files
committed
added threshold-clipping to latency cell
1 parent bf72094 commit c894b8a

File tree

1 file changed

+21
-9
lines changed

1 file changed

+21
-9
lines changed

ngclearn/components/input_encoders/latencyCell.py

+21-9
Original file line numberDiff line numberDiff line change
@@ -149,13 +149,16 @@ class LatencyCell(JaxComponent):
149149
:Note: if this set to True, you will need to choose a useful value
150150
for the "num_steps" argument (>1), depending on how many steps simulated
151151
152+
clip_spikes: should values under threshold be removed/suppressed?
153+
(default: False)
154+
152155
num_steps: number of discrete time steps to consider for normalized latency
153156
code (only useful if "normalize" is set to True) (Default: 1)
154157
"""
155158

156159
# Define Functions
157160
def __init__(self, name, n_units, tau=1., threshold=0.01, first_spike_time=0.,
158-
linearize=False, normalize=False, num_steps=1.,
161+
linearize=False, normalize=False, clip_spikes=False, num_steps=1.,
159162
batch_size=1, **kwargs):
160163
super().__init__(name, **kwargs)
161164

@@ -164,6 +167,7 @@ def __init__(self, name, n_units, tau=1., threshold=0.01, first_spike_time=0.,
164167
self.tau = tau
165168
self.threshold = threshold
166169
self.linearize = linearize
170+
self.clip_spikes = clip_spikes
167171
## normalize latency code s.t. final spike(s) occur w/in num_steps
168172
self.normalize = normalize
169173
self.num_steps = num_steps
@@ -176,17 +180,22 @@ def __init__(self, name, n_units, tau=1., threshold=0.01, first_spike_time=0.,
176180
restVals = jnp.zeros((self.batch_size, self.n_units))
177181
self.inputs = Compartment(restVals, display_name="Input Stimulus") # input compartment
178182
self.outputs = Compartment(restVals, display_name="Spikes") # output compartment
179-
self.mask = Compartment(restVals, display_name="Mask Variable") # output compartment
183+
self.mask = Compartment(restVals, display_name="Spike Time Mask")
184+
self.clip_mask = Compartment(restVals, display_name="Clip Mask")
180185
self.tols = Compartment(restVals, display_name="Time-of-Last-Spike", units="ms") # time of last spike
181186
self.targ_sp_times = Compartment(restVals, display_name="Target Spike Time", units="ms")
182187
#self.reset()
183188

184189
@staticmethod
185190
def _calc_spike_times(linearize, tau, threshold, first_spike_time, num_steps,
186-
normalize, inputs):
191+
normalize, clip_spikes, inputs):
187192
## would call this function before processing a spike train (at start)
188193
data = inputs
189-
if linearize == True: ## linearize spike time calculation
194+
if clip_spikes:
195+
clip_mask = (data < threshold) * 1. ## find values under threshold
196+
else:
197+
clip_mask = data * 0. ## all values allowed to fire spikes
198+
if linearize: ## linearize spike time calculation
190199
stimes = _calc_spike_times_linear(data, tau, threshold,
191200
first_spike_time,
192201
num_steps, normalize)
@@ -197,18 +206,20 @@ def _calc_spike_times(linearize, tau, threshold, first_spike_time, num_steps,
197206
num_steps=num_steps,
198207
normalize=normalize)
199208
targ_sp_times = stimes #* calcEvent + targ_sp_times * (1. - calcEvent)
200-
return targ_sp_times
209+
return targ_sp_times, clip_mask
201210

202211
@resolver(_calc_spike_times)
203-
def calc_spike_times(self, targ_sp_times):
212+
def calc_spike_times(self, targ_sp_times, clip_mask):
204213
self.targ_sp_times.set(targ_sp_times)
214+
self.clip_mask.set(clip_mask)
205215

206216
@staticmethod
207-
def _advance_state(t, dt, key, inputs, mask, targ_sp_times, tols):
217+
def _advance_state(t, dt, key, inputs, mask, clip_mask, targ_sp_times, tols):
208218
key, *subkeys = random.split(key, 2)
209219
data = inputs ## get sensory pattern data / features
210220
spikes, spk_mask = _extract_spike(targ_sp_times, t, mask) ## get spikes at t
211221
tols = _update_times(t, spikes, tols)
222+
spikes = spikes * (1. - clip_mask)
212223
return spikes, tols, spk_mask, targ_sp_times, key
213224

214225
@resolver(_advance_state)
@@ -222,14 +233,15 @@ def advance_state(self, outputs, tols, mask, targ_sp_times, key):
222233
@staticmethod
223234
def _reset(batch_size, n_units):
224235
restVals = jnp.zeros((batch_size, n_units))
225-
return (restVals, restVals, restVals, restVals, restVals)
236+
return (restVals, restVals, restVals, restVals, restVals, restVals)
226237

227238
@resolver(_reset)
228-
def reset(self, inputs, outputs, tols, mask, targ_sp_times):
239+
def reset(self, inputs, outputs, tols, mask, clip_mask, targ_sp_times):
229240
self.inputs.set(inputs)
230241
self.outputs.set(outputs)
231242
self.tols.set(tols)
232243
self.mask.set(mask)
244+
self.clip_mask.set(clip_mask)
233245
self.targ_sp_times.set(targ_sp_times)
234246

235247
def save(self, directory, **kwargs):

0 commit comments

Comments
 (0)