@@ -7,7 +7,7 @@ class GatedHebbianSynapse(DenseSynapse):
7
7
8
8
# Define Functions
9
9
def __init__ (self , name , shape , eta = 0. , weight_init = None , bias_init = None ,
10
- w_bound = 1. , w_decay = 0. , p_conn = 1. , resist_scale = 1. ,
10
+ w_bound = 1. , w_decay = 0. , alpha = 0. , p_conn = 1. , resist_scale = 1. ,
11
11
batch_size = 1 , ** kwargs ):
12
12
super ().__init__ (name , shape , weight_init , bias_init , resist_scale ,
13
13
p_conn , batch_size = batch_size , ** kwargs )
@@ -17,6 +17,7 @@ def __init__(self, name, shape, eta=0., weight_init=None, bias_init=None,
17
17
self .w_bound = w_bound
18
18
self .w_decay = w_decay ## synaptic decay
19
19
self .eta = eta
20
+ self .alpha = alpha
20
21
21
22
# compartments (state of the cell, parameters, will be updated through stateless calls)
22
23
self .preVals = jnp .zeros ((self .batch_size , shape [0 ]))
@@ -29,19 +30,19 @@ def __init__(self, name, shape, eta=0., weight_init=None, bias_init=None,
29
30
self .dBiases = Compartment (jnp .zeros (shape [1 ]))
30
31
31
32
@staticmethod
32
- def _compute_update (w_bound , pre , post , weights ):
33
+ def _compute_update (alpha , pre , post , weights ):
33
34
## calculate synaptic update values
34
35
dW = jnp .matmul (pre .T , post )
35
36
db = jnp .sum (post , axis = 0 , keepdims = True )
36
- # if w_bound > 0.:
37
- # dW = dW * (w_bound - jnp.abs(weights))
37
+ if alpha > 0. : ## apply synaptic dependency weighting
38
+ dW = dW * (alpha - jnp .abs (weights ))
38
39
return dW , db
39
40
40
41
@staticmethod
41
- def _evolve (bias_init , eta , w_decay , w_bound , pre , post , preSpike ,
42
+ def _evolve (bias_init , eta , alpha , w_decay , w_bound , pre , post , preSpike ,
42
43
postSpike , weights , biases ):
43
44
## calculate synaptic update values
44
- dWeights , dBiases = GatedHebbianSynapse ._compute_update (w_bound , pre , post , weights )
45
+ dWeights , dBiases = GatedHebbianSynapse ._compute_update (alpha , pre , post , weights )
45
46
weights = weights + dWeights * eta
46
47
if w_decay > 0. :
47
48
Wdec = jnp .matmul ((1. - preSpike ).T , postSpike ) * w_decay
0 commit comments