Skip to content

Commit 086bd4d

Browse files
committed
added meta-data to rate-cell, input encoders, adex
1 parent e240644 commit 086bd4d

File tree

5 files changed

+21
-20
lines changed

5 files changed

+21
-20
lines changed

Diff for: ngclearn/components/input_encoders/bernoulliCell.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,9 @@ def __init__(self, name, n_units, batch_size=1, **kwargs):
6565

6666
# Compartments (state of the cell, parameters, will be updated through stateless calls)
6767
restVals = jnp.zeros((self.batch_size, self.n_units))
68-
self.inputs = Compartment(restVals) # input compartment
69-
self.outputs = Compartment(restVals) # output compartment
70-
self.tols = Compartment(restVals) # time of last spike
68+
self.inputs = Compartment(restVals, display_name="Input Stimulus") # input compartment
69+
self.outputs = Compartment(restVals, display_name="Spikes") # output compartment
70+
self.tols = Compartment(restVals, display_name="Time-of-Last-Spike", units="ms") # time of last spike
7171

7272
@staticmethod
7373
def _advance_state(t, key, inputs, tols):

Diff for: ngclearn/components/input_encoders/latencyCell.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -173,11 +173,11 @@ def __init__(self, name, n_units, tau=1., threshold=0.01, first_spike_time=0.,
173173

174174
## Compartment setup
175175
restVals = jnp.zeros((self.batch_size, self.n_units))
176-
self.inputs = Compartment(restVals) # input compartment
177-
self.outputs = Compartment(restVals) # output compartment
178-
self.mask = Compartment(restVals) # output compartment
179-
self.tols = Compartment(restVals) # time of last spike
180-
self.targ_sp_times = Compartment(restVals)
176+
self.inputs = Compartment(restVals, display_name="Input Stimulus") # input compartment
177+
self.outputs = Compartment(restVals, display_name="Spikes") # output compartment
178+
self.mask = Compartment(restVals, display_name="Mask Variable") # output compartment
179+
self.tols = Compartment(restVals, display_name="Time-of-Last-Spike", units="ms") # time of last spike
180+
self.targ_sp_times = Compartment(restVals, display_name="Target Spike Time", units="ms")
181181
#self.reset()
182182

183183
@staticmethod

Diff for: ngclearn/components/input_encoders/poissonCell.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -77,9 +77,9 @@ def __init__(self, name, n_units, max_freq=63.75, batch_size=1, **kwargs):
7777

7878
## Compartment setup
7979
restVals = jnp.zeros((self.batch_size, self.n_units))
80-
self.inputs = Compartment(restVals) # input compartment
81-
self.outputs = Compartment(restVals) # output compartment
82-
self.tols = Compartment(restVals) # time of last spike
80+
self.inputs = Compartment(restVals, display_name="Input Stimulus") # input compartment
81+
self.outputs = Compartment(restVals, display_name="Spikes") # output compartment
82+
self.tols = Compartment(restVals, display_name="Time-of-Last-Spike", units="ms") # time of last spike
8383

8484
@staticmethod
8585
def _advance_state(t, dt, max_freq, key, inputs, tols):

Diff for: ngclearn/components/neurons/graded/rateCell.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -181,10 +181,10 @@ def __init__(self, name, n_units, tau_m, prior=("gaussian", 0.), act_fx="identit
181181

182182
# compartments (state of the cell & parameters will be updated through stateless calls)
183183
restVals = jnp.zeros(_shape)
184-
self.j = Compartment(restVals) # electrical current
185-
self.zF = Compartment(restVals) # rate-coded output - activity
186-
self.j_td = Compartment(restVals) # top-down electrical current - pressure
187-
self.z = Compartment(restVals) # rate activity
184+
self.j = Compartment(restVals, display_name="Input Stimulus Current", units="mA") # electrical current
185+
self.zF = Compartment(restVals, display_name="Transformed Rate Activity") # rate-coded output - activity
186+
self.j_td = Compartment(restVals, display_name="Modulatory Stimulus Current", units="mA") # top-down electrical current - pressure
187+
self.z = Compartment(restVals, display_name="Rate Activity", units="mA") # rate activity
188188

189189
@staticmethod
190190
def _advance_state(dt, fx, dfx, tau_m, priorLeakRate, intgFlag, priorType,

Diff for: ngclearn/components/neurons/spiking/adExCell.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -166,11 +166,12 @@ def __init__(self, name, n_units, tau_m=15., resist_m=1., tau_w=400.,
166166

167167
## Compartment setup
168168
restVals = jnp.zeros((self.batch_size, self.n_units))
169-
self.j = Compartment(restVals)
170-
self.v = Compartment(restVals + self.v0)
171-
self.w = Compartment(restVals + self.w0)
172-
self.s = Compartment(restVals)
173-
self.tols = Compartment(restVals) ## time-of-last-spike
169+
self.j = Compartment(restVals, display_name="Current", units="mA")
170+
self.v = Compartment(restVals + self.v0, display_name="Voltage", units="mV")
171+
self.w = Compartment(restVals + self.w0, display_name="Recovery")
172+
self.s = Compartment(restVals, display_name="Spikes")
173+
self.tols = Compartment(restVals, display_name="Time-of-Last-Spike",
174+
units="ms") ## time-of-last-spike
174175

175176
@staticmethod
176177
def _advance_state(t, dt, tau_m, R_m, tau_w, v_thr, a, b, sharpV, vT,

0 commit comments

Comments
 (0)