Skip to content

Commit 853b93a

Browse files
committed
mod to gated-hebb
1 parent 94b2453 commit 853b93a

File tree

1 file changed

+7
-6
lines changed

1 file changed

+7
-6
lines changed

ngclearn/components/synapses/hebbian/gatedHebbianSynapse.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ class GatedHebbianSynapse(DenseSynapse):
77

88
# Define Functions
99
def __init__(self, name, shape, eta=0., weight_init=None, bias_init=None,
10-
w_bound=1., w_decay=0., p_conn=1., resist_scale=1.,
10+
w_bound=1., w_decay=0., alpha=0., p_conn=1., resist_scale=1.,
1111
batch_size=1, **kwargs):
1212
super().__init__(name, shape, weight_init, bias_init, resist_scale,
1313
p_conn, batch_size=batch_size, **kwargs)
@@ -17,6 +17,7 @@ def __init__(self, name, shape, eta=0., weight_init=None, bias_init=None,
1717
self.w_bound = w_bound
1818
self.w_decay = w_decay ## synaptic decay
1919
self.eta = eta
20+
self.alpha = alpha
2021

2122
# compartments (state of the cell, parameters, will be updated through stateless calls)
2223
self.preVals = jnp.zeros((self.batch_size, shape[0]))
@@ -29,19 +30,19 @@ def __init__(self, name, shape, eta=0., weight_init=None, bias_init=None,
2930
self.dBiases = Compartment(jnp.zeros(shape[1]))
3031

3132
@staticmethod
32-
def _compute_update(w_bound, pre, post, weights):
33+
def _compute_update(alpha, pre, post, weights):
3334
## calculate synaptic update values
3435
dW = jnp.matmul(pre.T, post)
3536
db = jnp.sum(post, axis=0, keepdims=True)
36-
# if w_bound > 0.:
37-
# dW = dW * (w_bound - jnp.abs(weights))
37+
if alpha > 0.: ## apply synaptic dependency weighting
38+
dW = dW * (alpha - jnp.abs(weights))
3839
return dW, db
3940

4041
@staticmethod
41-
def _evolve(bias_init, eta, w_decay, w_bound, pre, post, preSpike,
42+
def _evolve(bias_init, eta, alpha, w_decay, w_bound, pre, post, preSpike,
4243
postSpike, weights, biases):
4344
## calculate synaptic update values
44-
dWeights, dBiases = GatedHebbianSynapse._compute_update(w_bound, pre, post, weights)
45+
dWeights, dBiases = GatedHebbianSynapse._compute_update(alpha, pre, post, weights)
4546
weights = weights + dWeights * eta
4647
if w_decay > 0.:
4748
Wdec = jnp.matmul((1. - preSpike).T, postSpike) * w_decay

0 commit comments

Comments
 (0)