Skip to content

Commit

Permalink
mod to gated-hebb
Browse files Browse the repository at this point in the history
  • Loading branch information
ago109 committed Jul 1, 2024
1 parent 94b2453 commit 853b93a
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions ngclearn/components/synapses/hebbian/gatedHebbianSynapse.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ class GatedHebbianSynapse(DenseSynapse):

# Define Functions
def __init__(self, name, shape, eta=0., weight_init=None, bias_init=None,
w_bound=1., w_decay=0., p_conn=1., resist_scale=1.,
w_bound=1., w_decay=0., alpha=0., p_conn=1., resist_scale=1.,
batch_size=1, **kwargs):
super().__init__(name, shape, weight_init, bias_init, resist_scale,
p_conn, batch_size=batch_size, **kwargs)
Expand All @@ -17,6 +17,7 @@ def __init__(self, name, shape, eta=0., weight_init=None, bias_init=None,
self.w_bound = w_bound
self.w_decay = w_decay ## synaptic decay
self.eta = eta
self.alpha = alpha

# compartments (state of the cell, parameters, will be updated through stateless calls)
self.preVals = jnp.zeros((self.batch_size, shape[0]))
Expand All @@ -29,19 +30,19 @@ def __init__(self, name, shape, eta=0., weight_init=None, bias_init=None,
self.dBiases = Compartment(jnp.zeros(shape[1]))

@staticmethod
def _compute_update(w_bound, pre, post, weights):
def _compute_update(alpha, pre, post, weights):
## calculate synaptic update values
dW = jnp.matmul(pre.T, post)
db = jnp.sum(post, axis=0, keepdims=True)
# if w_bound > 0.:
# dW = dW * (w_bound - jnp.abs(weights))
if alpha > 0.: ## apply synaptic dependency weighting
dW = dW * (alpha - jnp.abs(weights))
return dW, db

@staticmethod
def _evolve(bias_init, eta, w_decay, w_bound, pre, post, preSpike,
def _evolve(bias_init, eta, alpha, w_decay, w_bound, pre, post, preSpike,
postSpike, weights, biases):
## calculate synaptic update values
dWeights, dBiases = GatedHebbianSynapse._compute_update(w_bound, pre, post, weights)
dWeights, dBiases = GatedHebbianSynapse._compute_update(alpha, pre, post, weights)
weights = weights + dWeights * eta
if w_decay > 0.:
Wdec = jnp.matmul((1. - preSpike).T, postSpike) * w_decay
Expand Down

0 comments on commit 853b93a

Please sign in to comment.