@@ -279,21 +279,21 @@ def reset(self, j, v, s, s_raw, rfr, tols, surrogate):
279
279
def save (self , directory , ** kwargs ):
280
280
## do a protected save of constants, depending on whether they are floats or arrays
281
281
tau_m = (self .tau_m if isinstance (self .tau_m , float )
282
- else jnp .ones ([[self .tau_m ]]))
282
+ else jnp .asarray ([[self .tau_m * 1. ]]))
283
283
thr = (self .thr if isinstance (self .thr , float )
284
- else jnp .ones ([[self .thr ]]))
284
+ else jnp .asarray ([[self .thr * 1. ]]))
285
285
v_rest = (self .v_rest if isinstance (self .v_rest , float )
286
- else jnp .ones ([[self .v_rest ]]))
286
+ else jnp .asarray ([[self .v_rest * 1. ]]))
287
287
v_reset = (self .v_reset if isinstance (self .v_reset , float )
288
- else jnp .ones ([[self .v_reset ]]))
288
+ else jnp .asarray ([[self .v_reset * 1. ]]))
289
289
v_decay = (self .v_decay if isinstance (self .v_decay , float )
290
- else jnp .ones ([[self .v_decay ]]))
290
+ else jnp .asarray ([[self .v_decay * 1. ]]))
291
291
resist_m = (self .resist_m if isinstance (self .resist_m , float )
292
- else jnp .ones ([[self .resist_m ]]))
292
+ else jnp .asarray ([[self .resist_m * 1. ]]))
293
293
tau_theta = (self .tau_theta if isinstance (self .tau_theta , float )
294
- else jnp .ones ([[self .tau_theta ]]))
294
+ else jnp .asarray ([[self .tau_theta * 1. ]]))
295
295
theta_plus = (self .theta_plus if isinstance (self .theta_plus , float )
296
- else jnp .ones ([[self .theta_plus ]]))
296
+ else jnp .asarray ([[self .theta_plus * 1. ]]))
297
297
298
298
file_name = directory + "/" + self .name + ".npz"
299
299
jnp .savez (file_name ,
0 commit comments