@@ -162,7 +162,7 @@ def __init__(self, name, n_units, tau_m, prior=("gaussian", 0.), act_fx="identit
162
162
thresholdType , thr_lmbda = threshold
163
163
self .thresholdType = thresholdType ## type of thresholding function to use
164
164
self .thr_lmbda = thr_lmbda ## scale to drive thresholding dynamics
165
- self .Rscale = resist_scale ## a "resistance" scaling factor
165
+ self .resist_scale = resist_scale ## a "resistance" scaling factor
166
166
167
167
## integration properties
168
168
self .integrationType = integration_type
@@ -188,7 +188,7 @@ def __init__(self, name, n_units, tau_m, prior=("gaussian", 0.), act_fx="identit
188
188
189
189
@staticmethod
190
190
def _advance_state (dt , fx , dfx , tau_m , priorLeakRate , intgFlag , priorType ,
191
- Rscale , thresholdType , thr_lmbda , is_stateful , j , j_td , z ):
191
+ resist_scale , thresholdType , thr_lmbda , is_stateful , j , j_td , z ):
192
192
#if tau_m > 0.:
193
193
if is_stateful :
194
194
### run a step of integration over neuronal dynamics
@@ -197,7 +197,7 @@ def _advance_state(dt, fx, dfx, tau_m, priorLeakRate, intgFlag, priorType,
197
197
## self.current <-- "bottom-up" data-dependent signal
198
198
dfx_val = dfx (z )
199
199
j = _modulate (j , dfx_val )
200
- j = j * Rscale
200
+ j = j * resist_scale
201
201
tmp_z = _run_cell (dt , j , j_td , z ,
202
202
tau_m , leak_gamma = priorLeakRate ,
203
203
integType = intgFlag , priorType = priorType )
@@ -237,6 +237,30 @@ def reset(self, j, zF, j_td, z):
237
237
self .j_td .set (j_td ) # top-down electrical current - pressure
238
238
self .z .set (z ) # rate activity
239
239
240
+ def save (self , directory , ** kwargs ):
241
+ ## do a protected save of constants, depending on whether they are floats or arrays
242
+ tau_m = (self .tau_m if isinstance (self .tau_m , float )
243
+ else jnp .ones ([[self .tau_m ]]))
244
+ priorLeakRate = (self .priorLeakRate if isinstance (self .priorLeakRate , float )
245
+ else jnp .ones ([[self .priorLeakRate ]]))
246
+ resist_scale = (self .resist_scale if isinstance (self .resist_scale , float )
247
+ else jnp .ones ([[self .resist_scale ]]))
248
+
249
+ file_name = directory + "/" + self .name + ".npz"
250
+ jnp .savez (file_name ,
251
+ tau_m = tau_m , priorLeakRate = priorLeakRate ,
252
+ resist_scale = resist_scale ) #, key=self.key.value)
253
+
254
+ def load (self , directory , seeded = False , ** kwargs ):
255
+ file_name = directory + "/" + self .name + ".npz"
256
+ data = jnp .load (file_name )
257
+ ## constants loaded in
258
+ self .tau_m = data ['tau_m' ]
259
+ self .priorLeakRate = data ['priorLeakRate' ]
260
+ self .resist_scale = data ['resist_scale' ]
261
+ #if seeded:
262
+ # self.key.set(data['key'])
263
+
240
264
@classmethod
241
265
def help (cls ): ## component help function
242
266
properties = {
0 commit comments