Skip to content

Commit

Permalink
added meta-data to rate-cell, input encoders, adex
Browse files Browse the repository at this point in the history
  • Loading branch information
ago109 committed Jul 23, 2024
1 parent e240644 commit 086bd4d
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 20 deletions.
6 changes: 3 additions & 3 deletions ngclearn/components/input_encoders/bernoulliCell.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,9 @@ def __init__(self, name, n_units, batch_size=1, **kwargs):

# Compartments (state of the cell, parameters, will be updated through stateless calls)
restVals = jnp.zeros((self.batch_size, self.n_units))
self.inputs = Compartment(restVals) # input compartment
self.outputs = Compartment(restVals) # output compartment
self.tols = Compartment(restVals) # time of last spike
self.inputs = Compartment(restVals, display_name="Input Stimulus") # input compartment
self.outputs = Compartment(restVals, display_name="Spikes") # output compartment
self.tols = Compartment(restVals, display_name="Time-of-Last-Spike", units="ms") # time of last spike

@staticmethod
def _advance_state(t, key, inputs, tols):
Expand Down
10 changes: 5 additions & 5 deletions ngclearn/components/input_encoders/latencyCell.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,11 +173,11 @@ def __init__(self, name, n_units, tau=1., threshold=0.01, first_spike_time=0.,

## Compartment setup
restVals = jnp.zeros((self.batch_size, self.n_units))
self.inputs = Compartment(restVals) # input compartment
self.outputs = Compartment(restVals) # output compartment
self.mask = Compartment(restVals) # output compartment
self.tols = Compartment(restVals) # time of last spike
self.targ_sp_times = Compartment(restVals)
self.inputs = Compartment(restVals, display_name="Input Stimulus") # input compartment
self.outputs = Compartment(restVals, display_name="Spikes") # output compartment
self.mask = Compartment(restVals, display_name="Mask Variable") # output compartment
self.tols = Compartment(restVals, display_name="Time-of-Last-Spike", units="ms") # time of last spike
self.targ_sp_times = Compartment(restVals, display_name="Target Spike Time", units="ms")
#self.reset()

@staticmethod
Expand Down
6 changes: 3 additions & 3 deletions ngclearn/components/input_encoders/poissonCell.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,9 @@ def __init__(self, name, n_units, max_freq=63.75, batch_size=1, **kwargs):

## Compartment setup
restVals = jnp.zeros((self.batch_size, self.n_units))
self.inputs = Compartment(restVals) # input compartment
self.outputs = Compartment(restVals) # output compartment
self.tols = Compartment(restVals) # time of last spike
self.inputs = Compartment(restVals, display_name="Input Stimulus") # input compartment
self.outputs = Compartment(restVals, display_name="Spikes") # output compartment
self.tols = Compartment(restVals, display_name="Time-of-Last-Spike", units="ms") # time of last spike

@staticmethod
def _advance_state(t, dt, max_freq, key, inputs, tols):
Expand Down
8 changes: 4 additions & 4 deletions ngclearn/components/neurons/graded/rateCell.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,10 +181,10 @@ def __init__(self, name, n_units, tau_m, prior=("gaussian", 0.), act_fx="identit

# compartments (state of the cell & parameters will be updated through stateless calls)
restVals = jnp.zeros(_shape)
self.j = Compartment(restVals) # electrical current
self.zF = Compartment(restVals) # rate-coded output - activity
self.j_td = Compartment(restVals) # top-down electrical current - pressure
self.z = Compartment(restVals) # rate activity
self.j = Compartment(restVals, display_name="Input Stimulus Current", units="mA") # electrical current
self.zF = Compartment(restVals, display_name="Transformed Rate Activity") # rate-coded output - activity
self.j_td = Compartment(restVals, display_name="Modulatory Stimulus Current", units="mA") # top-down electrical current - pressure
self.z = Compartment(restVals, display_name="Rate Activity", units="mA") # rate activity

@staticmethod
def _advance_state(dt, fx, dfx, tau_m, priorLeakRate, intgFlag, priorType,
Expand Down
11 changes: 6 additions & 5 deletions ngclearn/components/neurons/spiking/adExCell.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,11 +166,12 @@ def __init__(self, name, n_units, tau_m=15., resist_m=1., tau_w=400.,

## Compartment setup
restVals = jnp.zeros((self.batch_size, self.n_units))
self.j = Compartment(restVals)
self.v = Compartment(restVals + self.v0)
self.w = Compartment(restVals + self.w0)
self.s = Compartment(restVals)
self.tols = Compartment(restVals) ## time-of-last-spike
self.j = Compartment(restVals, display_name="Current", units="mA")
self.v = Compartment(restVals + self.v0, display_name="Voltage", units="mV")
self.w = Compartment(restVals + self.w0, display_name="Recovery")
self.s = Compartment(restVals, display_name="Spikes")
self.tols = Compartment(restVals, display_name="Time-of-Last-Spike",
units="ms") ## time-of-last-spike

@staticmethod
def _advance_state(t, dt, tau_m, R_m, tau_w, v_thr, a, b, sharpV, vT,
Expand Down

0 comments on commit 086bd4d

Please sign in to comment.