From 086bd4da9f241bed0d07c1e0b6ad39a79f97839c Mon Sep 17 00:00:00 2001 From: ago109 Date: Tue, 23 Jul 2024 15:18:31 -0400 Subject: [PATCH] added meta-data to rate-cell, input encoders, adex --- ngclearn/components/input_encoders/bernoulliCell.py | 6 +++--- ngclearn/components/input_encoders/latencyCell.py | 10 +++++----- ngclearn/components/input_encoders/poissonCell.py | 6 +++--- ngclearn/components/neurons/graded/rateCell.py | 8 ++++---- ngclearn/components/neurons/spiking/adExCell.py | 11 ++++++----- 5 files changed, 21 insertions(+), 20 deletions(-) diff --git a/ngclearn/components/input_encoders/bernoulliCell.py b/ngclearn/components/input_encoders/bernoulliCell.py index 3ad409b8..da090bfb 100755 --- a/ngclearn/components/input_encoders/bernoulliCell.py +++ b/ngclearn/components/input_encoders/bernoulliCell.py @@ -65,9 +65,9 @@ def __init__(self, name, n_units, batch_size=1, **kwargs): # Compartments (state of the cell, parameters, will be updated through stateless calls) restVals = jnp.zeros((self.batch_size, self.n_units)) - self.inputs = Compartment(restVals) # input compartment - self.outputs = Compartment(restVals) # output compartment - self.tols = Compartment(restVals) # time of last spike + self.inputs = Compartment(restVals, display_name="Input Stimulus") # input compartment + self.outputs = Compartment(restVals, display_name="Spikes") # output compartment + self.tols = Compartment(restVals, display_name="Time-of-Last-Spike", units="ms") # time of last spike @staticmethod def _advance_state(t, key, inputs, tols): diff --git a/ngclearn/components/input_encoders/latencyCell.py b/ngclearn/components/input_encoders/latencyCell.py index 0a3045cc..1acf199f 100755 --- a/ngclearn/components/input_encoders/latencyCell.py +++ b/ngclearn/components/input_encoders/latencyCell.py @@ -173,11 +173,11 @@ def __init__(self, name, n_units, tau=1., threshold=0.01, first_spike_time=0., ## Compartment setup restVals = jnp.zeros((self.batch_size, self.n_units)) - self.inputs = Compartment(restVals) # input compartment - self.outputs = Compartment(restVals) # output compartment - self.mask = Compartment(restVals) # output compartment - self.tols = Compartment(restVals) # time of last spike - self.targ_sp_times = Compartment(restVals) + self.inputs = Compartment(restVals, display_name="Input Stimulus") # input compartment + self.outputs = Compartment(restVals, display_name="Spikes") # output compartment + self.mask = Compartment(restVals, display_name="Mask Variable") # output compartment + self.tols = Compartment(restVals, display_name="Time-of-Last-Spike", units="ms") # time of last spike + self.targ_sp_times = Compartment(restVals, display_name="Target Spike Time", units="ms") #self.reset() @staticmethod diff --git a/ngclearn/components/input_encoders/poissonCell.py b/ngclearn/components/input_encoders/poissonCell.py index cc1a3c8e..115afbc9 100644 --- a/ngclearn/components/input_encoders/poissonCell.py +++ b/ngclearn/components/input_encoders/poissonCell.py @@ -77,9 +77,9 @@ def __init__(self, name, n_units, max_freq=63.75, batch_size=1, **kwargs): ## Compartment setup restVals = jnp.zeros((self.batch_size, self.n_units)) - self.inputs = Compartment(restVals) # input compartment - self.outputs = Compartment(restVals) # output compartment - self.tols = Compartment(restVals) # time of last spike + self.inputs = Compartment(restVals, display_name="Input Stimulus") # input compartment + self.outputs = Compartment(restVals, display_name="Spikes") # output compartment + self.tols = Compartment(restVals, display_name="Time-of-Last-Spike", units="ms") # time of last spike @staticmethod def _advance_state(t, dt, max_freq, key, inputs, tols): diff --git a/ngclearn/components/neurons/graded/rateCell.py b/ngclearn/components/neurons/graded/rateCell.py index 748a5d4c..80ceea6d 100755 --- a/ngclearn/components/neurons/graded/rateCell.py +++ b/ngclearn/components/neurons/graded/rateCell.py @@ -181,10 +181,10 @@ def __init__(self, name, n_units, tau_m, prior=("gaussian", 0.), act_fx="identit # compartments (state of the cell & parameters will be updated through stateless calls) restVals = jnp.zeros(_shape) - self.j = Compartment(restVals) # electrical current - self.zF = Compartment(restVals) # rate-coded output - activity - self.j_td = Compartment(restVals) # top-down electrical current - pressure - self.z = Compartment(restVals) # rate activity + self.j = Compartment(restVals, display_name="Input Stimulus Current", units="mA") # electrical current + self.zF = Compartment(restVals, display_name="Transformed Rate Activity") # rate-coded output - activity + self.j_td = Compartment(restVals, display_name="Modulatory Stimulus Current", units="mA") # top-down electrical current - pressure + self.z = Compartment(restVals, display_name="Rate Activity", units="mA") # rate activity @staticmethod def _advance_state(dt, fx, dfx, tau_m, priorLeakRate, intgFlag, priorType, diff --git a/ngclearn/components/neurons/spiking/adExCell.py b/ngclearn/components/neurons/spiking/adExCell.py index 83680863..ef0c9450 100755 --- a/ngclearn/components/neurons/spiking/adExCell.py +++ b/ngclearn/components/neurons/spiking/adExCell.py @@ -166,11 +166,12 @@ def __init__(self, name, n_units, tau_m=15., resist_m=1., tau_w=400., ## Compartment setup restVals = jnp.zeros((self.batch_size, self.n_units)) - self.j = Compartment(restVals) - self.v = Compartment(restVals + self.v0) - self.w = Compartment(restVals + self.w0) - self.s = Compartment(restVals) - self.tols = Compartment(restVals) ## time-of-last-spike + self.j = Compartment(restVals, display_name="Current", units="mA") + self.v = Compartment(restVals + self.v0, display_name="Voltage", units="mV") + self.w = Compartment(restVals + self.w0, display_name="Recovery") + self.s = Compartment(restVals, display_name="Spikes") + self.tols = Compartment(restVals, display_name="Time-of-Last-Spike", + units="ms") ## time-of-last-spike @staticmethod def _advance_state(t, dt, tau_m, R_m, tau_w, v_thr, a, b, sharpV, vT,