Skip to content

Commit 8882208

Browse files
committed
cleaned up raf
1 parent dd49e5f commit 8882208

File tree

1 file changed

+22
-21
lines changed

1 file changed

+22
-21
lines changed

ngclearn/components/neurons/spiking/RAFCell.py

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,9 @@ class RAFCell(JaxComponent):
6161
The specific pair of differential equations that characterize this cell
6262
are (for adjusting v and w, given current j, over time):
6363
64-
| tau_m * dv/dt = -(v - v_rest) + sharpV * exp((v - vT)/sharpV) - R_m * w + R_m * j
65-
| tau_w * dw/dt = -w + (v - v_rest) * a
66-
| where w = w + s * (w + b) [in the event of a spike]
64+
| tau_m * dv/dt = omega * w + v * b
65+
| tau_w * dw/dt = w * b - v * omega + j
66+
| where omega is angular frequency (Hz) and b is exponential dampening factor
6767
6868
| --- Cell Input Compartments: ---
6969
| j - electrical current input (takes in external signals)
@@ -93,13 +93,11 @@ class RAFCell(JaxComponent):
9393
thr: voltage/membrane threshold (to obtain action potentials in terms
9494
of binary spikes) (Default: 5 mV)
9595
96-
v_rest: membrane resting potential (Default: -72 mV)
96+
v_reset: membrane reset potential condition (Default: 0 mV)
9797
98-
b: oscillation dampening factor (Default: -1.)
99-
100-
v0: initial condition / reset for voltage (Default: -70 mV)
98+
w_reset: reset condition for angular driver (Default: 0 mV)
10199
102-
w0: initial condition / reset for angular driver (Default: 0 mV)
100+
b: oscillation dampening factor (Default: -1.)
103101
104102
integration_type: type of integration to use for this cell's dynamics;
105103
current supported forms include "euler" (Euler/RK-1 integration)
@@ -112,9 +110,9 @@ class RAFCell(JaxComponent):
112110

113111
# Define Functions
114112
def __init__(self, name, n_units, tau_m=15., resist_m=1., tau_w=400.,
115-
omega=10., thr=5., v_rest=-72.,
116-
v_reset=-75., w_reset=0., b=-1., v0=-70., w0=0.,
113+
omega=10., thr=5., v_reset=0., w_reset=0., b=-1.,
117114
integration_type="euler", batch_size=1, **kwargs):
115+
#v_rest=-72., v_reset=-75., w_reset=0., thr=5., v0=-70., w0=0.,
118116
super().__init__(name, **kwargs)
119117

120118
## Integration properties
@@ -128,12 +126,9 @@ def __init__(self, name, n_units, tau_m=15., resist_m=1., tau_w=400.,
128126
self.omega = omega ## angular frequency
129127
self.b = b ## dampening factor
130128
## note: the smaller b is, the faster the oscillation dampens to resting state values
131-
self.v_rest = v_rest
129+
#self.v_rest = v_rest
132130
self.v_reset = v_reset
133131
self.w_reset = w_reset
134-
135-
self.v0 = v0 ## initial membrane potential/voltage condition
136-
self.w0 = w0 ## initial w-parameter condition
137132
self.thr = thr
138133

139134
## Layer Size Setup
@@ -150,8 +145,12 @@ def __init__(self, name, n_units, tau_m=15., resist_m=1., tau_w=400.,
150145
units="ms") ## time-of-last-spike
151146

152147
@staticmethod
153-
def _advance_state(t, dt, tau_m, resist_m, tau_w, thr, omega, b, v_rest,
148+
def _advance_state(t, dt, tau_m, resist_m, tau_w, thr, omega, b,
154149
v_reset, w_reset, intgFlag, j, v, w, tols):
150+
## center variables before running dynamics
151+
v = v - v_reset
152+
w = w - w_reset
153+
## continue with centered dynamics
155154
j_ = j * resist_m
156155
if intgFlag == 1: ## RK-2/midpoint
157156
w_params = (j_, v, tau_w, omega, b)
@@ -165,9 +164,11 @@ def _advance_state(t, dt, tau_m, resist_m, tau_w, thr, omega, b, v_rest,
165164
_, _v = step_euler(0., v, _dfv, dt, v_params)
166165
s = _emit_spike(_v, thr)
167166
## hyperpolarize/reset/snap variables
168-
v = _v * (1. - s) + s * v_reset
169-
w = _w * (1. - s) + s * w_reset
170-
167+
v = _v * (1. - s) + s #* v_reset
168+
w = _w * (1. - s) + s #* w_reset
169+
## artificially shift variables back to rest/reset values
170+
v = v + v_reset
171+
w = w + w_reset
171172
tols = _update_times(t, s, tols)
172173
return j, v, w, s, tols
173174

@@ -180,11 +181,11 @@ def advance_state(self, j, v, w, s, tols):
180181
self.tols.set(tols)
181182

182183
@staticmethod
183-
def _reset(batch_size, n_units, v0, w0):
184+
def _reset(batch_size, n_units, v_reset, w_reset):
184185
restVals = jnp.zeros((batch_size, n_units))
185186
j = restVals # None
186-
v = restVals + v0
187-
w = restVals + w0
187+
v = restVals + v_reset
188+
w = restVals + w_reset
188189
s = restVals #+ 0
189190
tols = restVals #+ 0
190191
return j, v, w, s, tols

0 commit comments

Comments
 (0)