From ee4e401b143476a4d444fe2d0435ee717abdc273 Mon Sep 17 00:00:00 2001 From: ago109 Date: Mon, 1 Jul 2024 16:12:00 -0400 Subject: [PATCH] minor cleanup of hebb syn --- .../synapses/hebbian/hebbianSynapse.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/ngclearn/components/synapses/hebbian/hebbianSynapse.py b/ngclearn/components/synapses/hebbian/hebbianSynapse.py index dbc3d5445..3876c1154 100644 --- a/ngclearn/components/synapses/hebbian/hebbianSynapse.py +++ b/ngclearn/components/synapses/hebbian/hebbianSynapse.py @@ -220,6 +220,15 @@ def _reset(batch_size, shape): jnp.zeros(shape[1]), # db ) + @resolver(_reset) + def reset(self, inputs, outputs, pre, post, dWeights, dBiases): + self.inputs.set(inputs) + self.outputs.set(outputs) + self.pre.set(pre) + self.post.set(post) + self.dWeights.set(dWeights) + self.dBiases.set(dBiases) + @classmethod def help(cls): ## component help function properties = { @@ -265,15 +274,6 @@ def help(cls): ## component help function "hyperparameters": hyperparams} return info - @resolver(_reset) - def reset(self, inputs, outputs, pre, post, dWeights, dBiases): - self.inputs.set(inputs) - self.outputs.set(outputs) - self.pre.set(pre) - self.post.set(post) - self.dWeights.set(dWeights) - self.dBiases.set(dBiases) - def __repr__(self): comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))] maxlen = max(len(c) for c in comps) + 5