@@ -61,9 +61,9 @@ class RAFCell(JaxComponent):
61
61
The specific pair of differential equations that characterize this cell
62
62
are (for adjusting v and w, given current j, over time):
63
63
64
- | tau_m * dv/dt = -(v - v_rest) + sharpV * exp((v - vT)/sharpV) - R_m * w + R_m * j
65
- | tau_w * dw/dt = -w + (v - v_rest) * a
66
- | where w = w + s * (w + b) [in the event of a spike]
64
+ | tau_m * dv/dt = omega * w + v * b
65
+ | tau_w * dw/dt = w * b - v * omega + j
66
+ | where omega is angular frequency (Hz) and b is exponential dampening factor
67
67
68
68
| --- Cell Input Compartments: ---
69
69
| j - electrical current input (takes in external signals)
@@ -93,13 +93,11 @@ class RAFCell(JaxComponent):
93
93
thr: voltage/membrane threshold (to obtain action potentials in terms
94
94
of binary spikes) (Default: 5 mV)
95
95
96
- v_rest : membrane resting potential (Default: -72 mV)
96
+ v_reset : membrane reset potential condition (Default: 0 mV)
97
97
98
- b: oscillation dampening factor (Default: -1.)
99
-
100
- v0: initial condition / reset for voltage (Default: -70 mV)
98
+ w_reset: reset condition for angular driver (Default: 0 mV)
101
99
102
- w0: initial condition / reset for angular driver (Default: 0 mV )
100
+ b: oscillation dampening factor (Default: -1. )
103
101
104
102
integration_type: type of integration to use for this cell's dynamics;
105
103
current supported forms include "euler" (Euler/RK-1 integration)
@@ -112,9 +110,9 @@ class RAFCell(JaxComponent):
112
110
113
111
# Define Functions
114
112
def __init__ (self , name , n_units , tau_m = 15. , resist_m = 1. , tau_w = 400. ,
115
- omega = 10. , thr = 5. , v_rest = - 72. ,
116
- v_reset = - 75. , w_reset = 0. , b = - 1. , v0 = - 70. , w0 = 0. ,
113
+ omega = 10. , thr = 5. , v_reset = 0. , w_reset = 0. , b = - 1. ,
117
114
integration_type = "euler" , batch_size = 1 , ** kwargs ):
115
+ #v_rest=-72., v_reset=-75., w_reset=0., thr=5., v0=-70., w0=0.,
118
116
super ().__init__ (name , ** kwargs )
119
117
120
118
## Integration properties
@@ -128,12 +126,9 @@ def __init__(self, name, n_units, tau_m=15., resist_m=1., tau_w=400.,
128
126
self .omega = omega ## angular frequency
129
127
self .b = b ## dampening factor
130
128
## note: the smaller b is, the faster the oscillation dampens to resting state values
131
- self .v_rest = v_rest
129
+ # self.v_rest = v_rest
132
130
self .v_reset = v_reset
133
131
self .w_reset = w_reset
134
-
135
- self .v0 = v0 ## initial membrane potential/voltage condition
136
- self .w0 = w0 ## initial w-parameter condition
137
132
self .thr = thr
138
133
139
134
## Layer Size Setup
@@ -150,8 +145,12 @@ def __init__(self, name, n_units, tau_m=15., resist_m=1., tau_w=400.,
150
145
units = "ms" ) ## time-of-last-spike
151
146
152
147
@staticmethod
153
- def _advance_state (t , dt , tau_m , resist_m , tau_w , thr , omega , b , v_rest ,
148
+ def _advance_state (t , dt , tau_m , resist_m , tau_w , thr , omega , b ,
154
149
v_reset , w_reset , intgFlag , j , v , w , tols ):
150
+ ## center variables before running dynamics
151
+ v = v - v_reset
152
+ w = w - w_reset
153
+ ## continue with centered dynamics
155
154
j_ = j * resist_m
156
155
if intgFlag == 1 : ## RK-2/midpoint
157
156
w_params = (j_ , v , tau_w , omega , b )
@@ -165,9 +164,11 @@ def _advance_state(t, dt, tau_m, resist_m, tau_w, thr, omega, b, v_rest,
165
164
_ , _v = step_euler (0. , v , _dfv , dt , v_params )
166
165
s = _emit_spike (_v , thr )
167
166
## hyperpolarize/reset/snap variables
168
- v = _v * (1. - s ) + s * v_reset
169
- w = _w * (1. - s ) + s * w_reset
170
-
167
+ v = _v * (1. - s ) + s #* v_reset
168
+ w = _w * (1. - s ) + s #* w_reset
169
+ ## artificially shift variables back to rest/reset values
170
+ v = v + v_reset
171
+ w = w + w_reset
171
172
tols = _update_times (t , s , tols )
172
173
return j , v , w , s , tols
173
174
@@ -180,11 +181,11 @@ def advance_state(self, j, v, w, s, tols):
180
181
self .tols .set (tols )
181
182
182
183
@staticmethod
183
- def _reset (batch_size , n_units , v0 , w0 ):
184
+ def _reset (batch_size , n_units , v_reset , w_reset ):
184
185
restVals = jnp .zeros ((batch_size , n_units ))
185
186
j = restVals # None
186
- v = restVals + v0
187
- w = restVals + w0
187
+ v = restVals + v_reset
188
+ w = restVals + w_reset
188
189
s = restVals #+ 0
189
190
tols = restVals #+ 0
190
191
return j , v , w , s , tols
0 commit comments