Skip to content

Commit

Permalink
added threshold-clipping to latency cell
Browse files Browse the repository at this point in the history
  • Loading branch information
ago109 committed Jul 25, 2024
1 parent bf72094 commit c894b8a
Showing 1 changed file with 21 additions and 9 deletions.
30 changes: 21 additions & 9 deletions ngclearn/components/input_encoders/latencyCell.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,13 +149,16 @@ class LatencyCell(JaxComponent):
:Note: if this set to True, you will need to choose a useful value
for the "num_steps" argument (>1), depending on how many steps simulated
clip_spikes: should values under threshold be removed/suppressed?
(default: False)
num_steps: number of discrete time steps to consider for normalized latency
code (only useful if "normalize" is set to True) (Default: 1)
"""

# Define Functions
def __init__(self, name, n_units, tau=1., threshold=0.01, first_spike_time=0.,
linearize=False, normalize=False, num_steps=1.,
linearize=False, normalize=False, clip_spikes=False, num_steps=1.,
batch_size=1, **kwargs):
super().__init__(name, **kwargs)

Expand All @@ -164,6 +167,7 @@ def __init__(self, name, n_units, tau=1., threshold=0.01, first_spike_time=0.,
self.tau = tau
self.threshold = threshold
self.linearize = linearize
self.clip_spikes = clip_spikes
## normalize latency code s.t. final spike(s) occur w/in num_steps
self.normalize = normalize
self.num_steps = num_steps
Expand All @@ -176,17 +180,22 @@ def __init__(self, name, n_units, tau=1., threshold=0.01, first_spike_time=0.,
restVals = jnp.zeros((self.batch_size, self.n_units))
self.inputs = Compartment(restVals, display_name="Input Stimulus") # input compartment
self.outputs = Compartment(restVals, display_name="Spikes") # output compartment
self.mask = Compartment(restVals, display_name="Mask Variable") # output compartment
self.mask = Compartment(restVals, display_name="Spike Time Mask")
self.clip_mask = Compartment(restVals, display_name="Clip Mask")
self.tols = Compartment(restVals, display_name="Time-of-Last-Spike", units="ms") # time of last spike
self.targ_sp_times = Compartment(restVals, display_name="Target Spike Time", units="ms")
#self.reset()

@staticmethod
def _calc_spike_times(linearize, tau, threshold, first_spike_time, num_steps,
normalize, inputs):
normalize, clip_spikes, inputs):
## would call this function before processing a spike train (at start)
data = inputs
if linearize == True: ## linearize spike time calculation
if clip_spikes:
clip_mask = (data < threshold) * 1. ## find values under threshold
else:
clip_mask = data * 0. ## all values allowed to fire spikes
if linearize: ## linearize spike time calculation
stimes = _calc_spike_times_linear(data, tau, threshold,
first_spike_time,
num_steps, normalize)
Expand All @@ -197,18 +206,20 @@ def _calc_spike_times(linearize, tau, threshold, first_spike_time, num_steps,
num_steps=num_steps,
normalize=normalize)
targ_sp_times = stimes #* calcEvent + targ_sp_times * (1. - calcEvent)
return targ_sp_times
return targ_sp_times, clip_mask

@resolver(_calc_spike_times)
def calc_spike_times(self, targ_sp_times):
def calc_spike_times(self, targ_sp_times, clip_mask):
self.targ_sp_times.set(targ_sp_times)
self.clip_mask.set(clip_mask)

@staticmethod
def _advance_state(t, dt, key, inputs, mask, targ_sp_times, tols):
def _advance_state(t, dt, key, inputs, mask, clip_mask, targ_sp_times, tols):
key, *subkeys = random.split(key, 2)
data = inputs ## get sensory pattern data / features
spikes, spk_mask = _extract_spike(targ_sp_times, t, mask) ## get spikes at t
tols = _update_times(t, spikes, tols)
spikes = spikes * (1. - clip_mask)
return spikes, tols, spk_mask, targ_sp_times, key

@resolver(_advance_state)
Expand All @@ -222,14 +233,15 @@ def advance_state(self, outputs, tols, mask, targ_sp_times, key):
@staticmethod
def _reset(batch_size, n_units):
restVals = jnp.zeros((batch_size, n_units))
return (restVals, restVals, restVals, restVals, restVals)
return (restVals, restVals, restVals, restVals, restVals, restVals)

@resolver(_reset)
def reset(self, inputs, outputs, tols, mask, targ_sp_times):
def reset(self, inputs, outputs, tols, mask, clip_mask, targ_sp_times):
self.inputs.set(inputs)
self.outputs.set(outputs)
self.tols.set(tols)
self.mask.set(mask)
self.clip_mask.set(clip_mask)
self.targ_sp_times.set(targ_sp_times)

def save(self, directory, **kwargs):
Expand Down

0 comments on commit c894b8a

Please sign in to comment.