@@ -7,7 +7,7 @@ class GatedHebbianSynapse(DenseSynapse):
77
88 # Define Functions
99 def __init__ (self , name , shape , eta = 0. , weight_init = None , bias_init = None ,
10- w_bound = 1. , w_decay = 0. , p_conn = 1. , resist_scale = 1. ,
10+ w_bound = 1. , w_decay = 0. , alpha = 0. , p_conn = 1. , resist_scale = 1. ,
1111 batch_size = 1 , ** kwargs ):
1212 super ().__init__ (name , shape , weight_init , bias_init , resist_scale ,
1313 p_conn , batch_size = batch_size , ** kwargs )
@@ -17,6 +17,7 @@ def __init__(self, name, shape, eta=0., weight_init=None, bias_init=None,
1717 self .w_bound = w_bound
1818 self .w_decay = w_decay ## synaptic decay
1919 self .eta = eta
20+ self .alpha = alpha
2021
2122 # compartments (state of the cell, parameters, will be updated through stateless calls)
2223 self .preVals = jnp .zeros ((self .batch_size , shape [0 ]))
@@ -29,19 +30,19 @@ def __init__(self, name, shape, eta=0., weight_init=None, bias_init=None,
2930 self .dBiases = Compartment (jnp .zeros (shape [1 ]))
3031
3132 @staticmethod
32- def _compute_update (w_bound , pre , post , weights ):
33+ def _compute_update (alpha , pre , post , weights ):
3334 ## calculate synaptic update values
3435 dW = jnp .matmul (pre .T , post )
3536 db = jnp .sum (post , axis = 0 , keepdims = True )
36- # if w_bound > 0.:
37- # dW = dW * (w_bound - jnp.abs(weights))
37+ if alpha > 0. : ## apply synaptic dependency weighting
38+ dW = dW * (alpha - jnp .abs (weights ))
3839 return dW , db
3940
4041 @staticmethod
41- def _evolve (bias_init , eta , w_decay , w_bound , pre , post , preSpike ,
42+ def _evolve (bias_init , eta , alpha , w_decay , w_bound , pre , post , preSpike ,
4243 postSpike , weights , biases ):
4344 ## calculate synaptic update values
44- dWeights , dBiases = GatedHebbianSynapse ._compute_update (w_bound , pre , post , weights )
45+ dWeights , dBiases = GatedHebbianSynapse ._compute_update (alpha , pre , post , weights )
4546 weights = weights + dWeights * eta
4647 if w_decay > 0. :
4748 Wdec = jnp .matmul ((1. - preSpike ).T , postSpike ) * w_decay
0 commit comments