Skip to content

Commit 19bf502

Browse files
committed
fixed minor saving/loading in rate-cell w/ vectorized compartments
1 parent 086bd4d commit 19bf502

File tree

1 file changed

+27
-3
lines changed

1 file changed

+27
-3
lines changed

ngclearn/components/neurons/graded/rateCell.py

+27-3
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ def __init__(self, name, n_units, tau_m, prior=("gaussian", 0.), act_fx="identit
162162
thresholdType, thr_lmbda = threshold
163163
self.thresholdType = thresholdType ## type of thresholding function to use
164164
self.thr_lmbda = thr_lmbda ## scale to drive thresholding dynamics
165-
self.Rscale = resist_scale ## a "resistance" scaling factor
165+
self.resist_scale = resist_scale ## a "resistance" scaling factor
166166

167167
## integration properties
168168
self.integrationType = integration_type
@@ -188,7 +188,7 @@ def __init__(self, name, n_units, tau_m, prior=("gaussian", 0.), act_fx="identit
188188

189189
@staticmethod
190190
def _advance_state(dt, fx, dfx, tau_m, priorLeakRate, intgFlag, priorType,
191-
Rscale, thresholdType, thr_lmbda, is_stateful, j, j_td, z):
191+
resist_scale, thresholdType, thr_lmbda, is_stateful, j, j_td, z):
192192
#if tau_m > 0.:
193193
if is_stateful:
194194
### run a step of integration over neuronal dynamics
@@ -197,7 +197,7 @@ def _advance_state(dt, fx, dfx, tau_m, priorLeakRate, intgFlag, priorType,
197197
## self.current <-- "bottom-up" data-dependent signal
198198
dfx_val = dfx(z)
199199
j = _modulate(j, dfx_val)
200-
j = j * Rscale
200+
j = j * resist_scale
201201
tmp_z = _run_cell(dt, j, j_td, z,
202202
tau_m, leak_gamma=priorLeakRate,
203203
integType=intgFlag, priorType=priorType)
@@ -237,6 +237,30 @@ def reset(self, j, zF, j_td, z):
237237
self.j_td.set(j_td) # top-down electrical current - pressure
238238
self.z.set(z) # rate activity
239239

240+
def save(self, directory, **kwargs):
241+
## do a protected save of constants, depending on whether they are floats or arrays
242+
tau_m = (self.tau_m if isinstance(self.tau_m, float)
243+
else jnp.ones([[self.tau_m]]))
244+
priorLeakRate = (self.priorLeakRate if isinstance(self.priorLeakRate, float)
245+
else jnp.ones([[self.priorLeakRate]]))
246+
resist_scale = (self.resist_scale if isinstance(self.resist_scale, float)
247+
else jnp.ones([[self.resist_scale]]))
248+
249+
file_name = directory + "/" + self.name + ".npz"
250+
jnp.savez(file_name,
251+
tau_m=tau_m, priorLeakRate=priorLeakRate,
252+
resist_scale=resist_scale) #, key=self.key.value)
253+
254+
def load(self, directory, seeded=False, **kwargs):
255+
file_name = directory + "/" + self.name + ".npz"
256+
data = jnp.load(file_name)
257+
## constants loaded in
258+
self.tau_m = data['tau_m']
259+
self.priorLeakRate = data['priorLeakRate']
260+
self.resist_scale = data['resist_scale']
261+
#if seeded:
262+
# self.key.set(data['key'])
263+
240264
@classmethod
241265
def help(cls): ## component help function
242266
properties = {

0 commit comments

Comments
 (0)