diff --git a/ngclearn/components/synapses/hebbian/gatedHebbianSynapse.py b/ngclearn/components/synapses/hebbian/gatedHebbianSynapse.py index 3d17f77cf..a3e21103d 100755 --- a/ngclearn/components/synapses/hebbian/gatedHebbianSynapse.py +++ b/ngclearn/components/synapses/hebbian/gatedHebbianSynapse.py @@ -7,7 +7,7 @@ class GatedHebbianSynapse(DenseSynapse): # Define Functions def __init__(self, name, shape, eta=0., weight_init=None, bias_init=None, - w_bound=1., w_decay=0., p_conn=1., resist_scale=1., + w_bound=1., w_decay=0., alpha=0., p_conn=1., resist_scale=1., batch_size=1, **kwargs): super().__init__(name, shape, weight_init, bias_init, resist_scale, p_conn, batch_size=batch_size, **kwargs) @@ -17,6 +17,7 @@ def __init__(self, name, shape, eta=0., weight_init=None, bias_init=None, self.w_bound = w_bound self.w_decay = w_decay ## synaptic decay self.eta = eta + self.alpha = alpha # compartments (state of the cell, parameters, will be updated through stateless calls) self.preVals = jnp.zeros((self.batch_size, shape[0])) @@ -29,19 +30,19 @@ def __init__(self, name, shape, eta=0., weight_init=None, bias_init=None, self.dBiases = Compartment(jnp.zeros(shape[1])) @staticmethod - def _compute_update(w_bound, pre, post, weights): + def _compute_update(alpha, pre, post, weights): ## calculate synaptic update values dW = jnp.matmul(pre.T, post) db = jnp.sum(post, axis=0, keepdims=True) - # if w_bound > 0.: - # dW = dW * (w_bound - jnp.abs(weights)) + if alpha > 0.: ## apply synaptic dependency weighting + dW = dW * (alpha - jnp.abs(weights)) return dW, db @staticmethod - def _evolve(bias_init, eta, w_decay, w_bound, pre, post, preSpike, + def _evolve(bias_init, eta, alpha, w_decay, w_bound, pre, post, preSpike, postSpike, weights, biases): ## calculate synaptic update values - dWeights, dBiases = GatedHebbianSynapse._compute_update(w_bound, pre, post, weights) + dWeights, dBiases = GatedHebbianSynapse._compute_update(alpha, pre, post, weights) weights = weights + dWeights * eta if w_decay > 0.: Wdec = jnp.matmul((1. - preSpike).T, postSpike) * w_decay