@@ -56,7 +56,7 @@ class EventSTDPSynapse(DenseSynapse): # event-driven, post-synaptic STDP
56
56
57
57
# Define Functions
58
58
def __init__ (self , name , shape , eta , lmbda = 0.01 , A_plus = 1. , A_minus = 1. ,
59
- presyn_win_len = 1 . , w_bound = 1. , weight_init = None , resist_scale = 1. ,
59
+ presyn_win_len = 2 . , w_bound = 1. , weight_init = None , resist_scale = 1. ,
60
60
p_conn = 1. , batch_size = 1 , ** kwargs ):
61
61
super ().__init__ (name , shape , weight_init , None , resist_scale , p_conn ,
62
62
batch_size = batch_size , ** kwargs )
@@ -65,6 +65,7 @@ def __init__(self, name, shape, eta, lmbda=0.01, A_plus=1., A_minus=1.,
65
65
self .eta = eta ## global learning rate governing plasticity
66
66
self .lmbda = lmbda ## controls scaling of STDP rule
67
67
self .presyn_win_len = presyn_win_len
68
+ assert self .presyn_win_len >= 0. ## pre-synaptic window must be non-negative
68
69
self .Aplus = A_plus
69
70
self .Aminus = A_minus
70
71
self .Rscale = resist_scale ## post-transformation scale factor
@@ -82,10 +83,13 @@ def __init__(self, name, shape, eta, lmbda=0.01, A_plus=1., A_minus=1.,
82
83
def _compute_update (t , lmbda , presyn_win_len , Aminus , Aplus , w_bound , pre_tols ,
83
84
postSpike , weights ):
84
85
## check if a spike occurred in window of (t - presyn_win_len, t]
85
- m = (pre_tols > 0. ) * 1. ## ignore default value of tols = 0 ms
86
- lbound = ((t - presyn_win_len ) < pre_tols ) * 1.
87
- rbound = (pre_tols <= t ) * 1.
88
- preSpike = lbound * rbound * m
86
+ m = (pre_tols > 0. ) * 1. ## ignore default value of tols = 0 ms
87
+ if presyn_win_len > 0. :
88
+ lbound = ((t - presyn_win_len ) < pre_tols ) * 1.
89
+ preSpike = lbound * m
90
+ else :
91
+ check_spike = (pre_tols == t ) * 1.
92
+ preSpike = check_spike * m
89
93
## this implements a generalization of the rule in eqn 18 of the paper
90
94
pos_shift = w_bound - (weights * (1. + lmbda ))
91
95
pos_shift = pos_shift * Aplus
0 commit comments