Skip to content

Commit ba08453

Browse files
committed
updates to if/lif
1 parent c894b8a commit ba08453

File tree

2 files changed

+16
-16
lines changed

2 files changed

+16
-16
lines changed

ngclearn/components/neurons/spiking/IFCell.py

+8-8
Original file line numberDiff line numberDiff line change
@@ -213,21 +213,21 @@ def reset(self, j, v, s, rfr, tols, surrogate):
213213
def save(self, directory, **kwargs):
214214
## do a protected save of constants, depending on whether they are floats or arrays
215215
tau_m = (self.tau_m if isinstance(self.tau_m, float)
216-
else jnp.ones([[self.tau_m]]))
216+
else jnp.asarray([[self.tau_m * 1.]]))
217217
thr = (self.thr if isinstance(self.thr, float)
218-
else jnp.ones([[self.thr]]))
218+
else jnp.asarray([[self.thr * 1.]]))
219219
v_rest = (self.v_rest if isinstance(self.v_rest, float)
220-
else jnp.ones([[self.v_rest]]))
220+
else jnp.asarray([[self.v_rest * 1.]]))
221221
v_reset = (self.v_reset if isinstance(self.v_reset, float)
222-
else jnp.ones([[self.v_reset]]))
222+
else jnp.asarray([[self.v_reset * 1.]]))
223223
v_decay = (self.v_decay if isinstance(self.v_decay, float)
224-
else jnp.ones([[self.v_decay]]))
224+
else jnp.asarray([[self.v_decay * 1.]]))
225225
resist_m = (self.resist_m if isinstance(self.resist_m, float)
226-
else jnp.ones([[self.resist_m]]))
226+
else jnp.asarray([[self.resist_m * 1.]]))
227227
tau_theta = (self.tau_theta if isinstance(self.tau_theta, float)
228-
else jnp.ones([[self.tau_theta]]))
228+
else jnp.asarray([[self.tau_theta * 1.]]))
229229
theta_plus = (self.theta_plus if isinstance(self.theta_plus, float)
230-
else jnp.ones([[self.theta_plus]]))
230+
else jnp.asarray([[self.theta_plus * 1.]]))
231231

232232
file_name = directory + "/" + self.name + ".npz"
233233
jnp.savez(file_name,

ngclearn/components/neurons/spiking/LIFCell.py

+8-8
Original file line numberDiff line numberDiff line change
@@ -279,21 +279,21 @@ def reset(self, j, v, s, s_raw, rfr, tols, surrogate):
279279
def save(self, directory, **kwargs):
280280
## do a protected save of constants, depending on whether they are floats or arrays
281281
tau_m = (self.tau_m if isinstance(self.tau_m, float)
282-
else jnp.ones([[self.tau_m]]))
282+
else jnp.asarray([[self.tau_m * 1.]]))
283283
thr = (self.thr if isinstance(self.thr, float)
284-
else jnp.ones([[self.thr]]))
284+
else jnp.asarray([[self.thr * 1.]]))
285285
v_rest = (self.v_rest if isinstance(self.v_rest, float)
286-
else jnp.ones([[self.v_rest]]))
286+
else jnp.asarray([[self.v_rest * 1.]]))
287287
v_reset = (self.v_reset if isinstance(self.v_reset, float)
288-
else jnp.ones([[self.v_reset]]))
288+
else jnp.asarray([[self.v_reset * 1.]]))
289289
v_decay = (self.v_decay if isinstance(self.v_decay, float)
290-
else jnp.ones([[self.v_decay]]))
290+
else jnp.asarray([[self.v_decay * 1.]]))
291291
resist_m = (self.resist_m if isinstance(self.resist_m, float)
292-
else jnp.ones([[self.resist_m]]))
292+
else jnp.asarray([[self.resist_m * 1.]]))
293293
tau_theta = (self.tau_theta if isinstance(self.tau_theta, float)
294-
else jnp.ones([[self.tau_theta]]))
294+
else jnp.asarray([[self.tau_theta * 1.]]))
295295
theta_plus = (self.theta_plus if isinstance(self.theta_plus, float)
296-
else jnp.ones([[self.theta_plus]]))
296+
else jnp.asarray([[self.theta_plus * 1.]]))
297297

298298
file_name = directory + "/" + self.name + ".npz"
299299
jnp.savez(file_name,

0 commit comments

Comments
 (0)