
Commit f8aff77

willgebhardt and ago109 authored
Dev (#62)
* implemented raw classical instantaneous stdp
* mod to classical stdp
* mod to classical stdp syn
* mod to stdp syn
* mod to stdp syn
* mod to stdp syn
* mod to stdp syn
* mod to stdp syn
* minor mod of syn
* cleaned up stdp syn
* cleaned up stdp syn

Co-authored-by: ago109 <[email protected]>
1 parent cf639ad commit f8aff77

File tree

4 files changed: +221 −0 lines changed

* ngclearn/components/__init__.py
* ngclearn/components/synapses/__init__.py
* ngclearn/components/synapses/hebbian/STDPSynapse.py
* ngclearn/components/synapses/hebbian/__init__.py

ngclearn/components/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -24,6 +24,7 @@
 from .synapses.denseSynapse import DenseSynapse
 from .synapses.staticSynapse import StaticSynapse
 from .synapses.hebbian.hebbianSynapse import HebbianSynapse
+from .synapses.hebbian.STDPSynapse import STDPSynapse
 from .synapses.hebbian.traceSTDPSynapse import TraceSTDPSynapse
 from .synapses.hebbian.expSTDPSynapse import ExpSTDPSynapse
 from .synapses.hebbian.eventSTDPSynapse import EventSTDPSynapse

ngclearn/components/synapses/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -4,6 +4,7 @@
 from .STPDenseSynapse import STPDenseSynapse
 ## dense synaptic components
 from .hebbian.hebbianSynapse import HebbianSynapse
+from .hebbian.STDPSynapse import STDPSynapse
 from .hebbian.traceSTDPSynapse import TraceSTDPSynapse
 from .hebbian.expSTDPSynapse import ExpSTDPSynapse
 from .hebbian.eventSTDPSynapse import EventSTDPSynapse
ngclearn/components/synapses/hebbian/STDPSynapse.py

Lines changed: 218 additions & 0 deletions

@@ -0,0 +1,218 @@
from jax import random, numpy as jnp, jit
from ngclearn import resolver, Component, Compartment
from ngclearn.components.synapses import DenseSynapse
from ngclearn.utils import tensorstats

class STDPSynapse(DenseSynapse): ## classical/raw STDP
    """
    A synaptic cable that adjusts its efficacies via raw
    spike-timing-dependent plasticity (STDP).

    | --- Synapse Compartments: ---
    | inputs - input (takes in external signals)
    | outputs - output signals (transformation induced by synapses)
    | weights - current value matrix of synaptic efficacies
    | key - JAX PRNG key
    | --- Synaptic Plasticity Compartments: ---
    | preSpike - pre-synaptic spike to drive long-term depression (takes in external signals)
    | postSpike - post-synaptic spike to drive long-term potentiation (takes in external signals)
    | pre_tols - pre-synaptic time-of-last-spike (takes in external signals)
    | post_tols - post-synaptic time-of-last-spike (takes in external signals)
    | dWeights - current delta matrix containing changes to be applied to synaptic efficacies
    | eta - global learning rate (multiplier beyond A_plus and A_minus)

    | References:
    | Markram, Henry, et al. "Regulation of synaptic efficacy by coincidence of
    | postsynaptic APs and EPSPs." Science 275.5297 (1997): 213-215.
    |
    | Bi, Guo-qiang, and Mu-ming Poo. "Synaptic modification by correlated
    | activity: Hebb's postulate revisited." Annual review of neuroscience 24.1
    | (2001): 139-166.

    Args:
        name: the string name of this cell

        shape: tuple specifying shape of this synaptic cable (usually a 2-tuple
            with number of inputs by number of outputs)

        A_plus: strength of long-term potentiation (LTP)

        A_minus: strength of long-term depression (LTD)

        tau_plus: time constant of long-term potentiation (LTP)

        tau_minus: time constant of long-term depression (LTD)

        w_decay: synaptic decay factor applied under the simple ascent-style
            update (default: 0)

        eta: global learning rate initial value/condition (default: 1)

        tau_w: time constant for synaptic adjustment; setting this to zero
            disables Euler-style synaptic adjustment (default: 0)

        weight_init: a kernel to drive initialization of this synaptic cable's values;
            typically a tuple with 1st element as a string calling the name of
            initialization to use

        resist_scale: a fixed scaling factor to apply to the synaptic transform
            (default: 1.), i.e., yields: out = ((W * Rscale) * in)

        p_conn: probability of a connection existing (default: 1); setting
            this to < 1. will result in a sparser synaptic structure

        w_bound: maximum value/magnitude any synaptic efficacy can be (default: 1)
    """

    # Define Functions
    def __init__(self, name, shape, A_plus, A_minus, tau_plus=10., tau_minus=10., w_decay=0.,
                 eta=1., tau_w=0., weight_init=None, resist_scale=1., p_conn=1., w_bound=1.,
                 batch_size=1, **kwargs):
        super().__init__(name, shape, weight_init, None, resist_scale,
                         p_conn, batch_size=batch_size, **kwargs)
        assert self.batch_size == 1 ## note: STDP only supports online learning in this implementation
        ## Synaptic hyper-parameters
        self.shape = shape ## shape of synaptic efficacy matrix
        self.Aplus = A_plus ## LTP strength
        self.Aminus = A_minus ## LTD strength
        self.tau_plus = tau_plus ## LTP time constant
        self.tau_minus = tau_minus ## LTD time constant
        self.Rscale = resist_scale ## post-transformation scale factor
        self.w_bound = w_bound ## soft weight constraint
        self.tau_w = tau_w ## synaptic update time constant
        self.w_decay = w_decay ## synaptic decay factor

        ## Compartment setup
        preVals = jnp.zeros((self.batch_size, shape[0]))
        postVals = jnp.zeros((self.batch_size, shape[1]))
        self.preSpike = Compartment(preVals)
        self.postSpike = Compartment(postVals)
        self.pre_tols = Compartment(preVals) ## pre-synaptic time-of-last-spike
        self.post_tols = Compartment(postVals) ## post-synaptic time-of-last-spike
        self.dWeights = Compartment(self.weights.value * 0)
        self.eta = Compartment(jnp.ones((1, 1)) * eta) ## global learning rate

    @staticmethod
    def _compute_update(Aplus, Aminus, tau_plus, tau_minus, preSpike, postSpike,
                        pre_tols, post_tols, weights):
        ## calculate time-deltas matrix block --> (t_post - t_pre)
        post_m = (post_tols > 0.) ## zero post-tols mask
        pre_m = (pre_tols > 0.).T ## zero pre-tols mask
        t_delta = ((weights * 0 + 1.) * post_tols) - pre_tols.T ## t_delta.shape = weights.shape
        t_delta = t_delta * post_m * pre_m ## mask out zero tols
        pos_t_delta_m = (t_delta > 0.) ## positive t-delta mask
        neg_t_delta_m = (t_delta < 0.) ## negative t-delta mask (also excludes same-time spikes)
        ## calculate post-synaptic (LTP) term
        postTerm = jnp.exp(-t_delta / tau_plus) * pos_t_delta_m
        dWpost = postTerm * (postSpike * Aplus)
        dWpre = 0.
        if Aminus > 0.:
            ## calculate pre-synaptic (LTD) term; t_delta < 0 under this mask, so
            ## exp(t_delta / tau_minus) decays as the spike-time gap widens
            preTerm = jnp.exp(t_delta / tau_minus) * neg_t_delta_m
            dWpre = -preTerm * (preSpike.T * Aminus)
        ## calculate final weighted adjustment
        dW = (dWpost + dWpre)
        return dW

    @staticmethod
    def _evolve(dt, w_bound, w_decay, tau_w, Aplus, Aminus, tau_plus, tau_minus, preSpike,
                postSpike, pre_tols, post_tols, weights, eta):
        dWeights = STDPSynapse._compute_update(
            Aplus, Aminus, tau_plus, tau_minus, preSpike, postSpike, pre_tols,
            post_tols, weights
        )
        ## shift/alter values of synaptic efficacies
        if tau_w > 0.: ## triggers Euler-style synaptic update
            weights = weights + (-weights * dt / tau_w + dWeights * eta)
        else: ## raw simple ascent-style update
            weights = weights + dWeights * eta - weights * w_decay
        ## enforce hard bounds, keeping efficacies within (0, w_bound)
        eps = 0.001
        weights = jnp.clip(weights, eps, w_bound - eps)
        return weights, dWeights

    @resolver(_evolve)
    def evolve(self, weights, dWeights):
        self.weights.set(weights)
        self.dWeights.set(dWeights)

    @staticmethod
    def _reset(batch_size, shape):
        preVals = jnp.zeros((batch_size, shape[0]))
        postVals = jnp.zeros((batch_size, shape[1]))
        inputs = preVals
        outputs = postVals
        preSpike = preVals
        postSpike = postVals
        pre_tols = preVals
        post_tols = postVals
        dWeights = jnp.zeros(shape)
        return inputs, outputs, preSpike, postSpike, pre_tols, post_tols, dWeights

    @resolver(_reset)
    def reset(self, inputs, outputs, preSpike, postSpike, pre_tols, post_tols, dWeights):
        self.inputs.set(inputs)
        self.outputs.set(outputs)
        self.preSpike.set(preSpike)
        self.postSpike.set(postSpike)
        self.pre_tols.set(pre_tols)
        self.post_tols.set(post_tols)
        self.dWeights.set(dWeights)

    @classmethod
    def help(cls): ## component help function
        properties = {
            "synapse_type": "STDPSynapse - performs an adaptable synaptic "
                            "transformation of inputs to produce output signals; "
                            "synapses are adjusted with classical "
                            "spike-timing-dependent plasticity (STDP)"
        }
        compartment_props = {
            "inputs":
                {"inputs": "Takes in external input signal values",
                 "preSpike": "Pre-synaptic spike compartment event for STDP (s_j)",
                 "postSpike": "Post-synaptic spike compartment event for STDP (s_i)",
                 "pre_tols": "Pre-synaptic time-of-last-spike (t_j)",
                 "post_tols": "Post-synaptic time-of-last-spike (t_i)"},
            "states":
                {"weights": "Synapse efficacy/strength parameter values",
                 "biases": "Base-rate/bias parameter values",
                 "eta": "Global learning rate (multiplier beyond A_plus and A_minus)",
                 "key": "JAX PRNG key"},
            "analytics":
                {"dWeights": "Synaptic weight value adjustment matrix produced at time t"},
            "outputs":
                {"outputs": "Output of synaptic transformation"},
        }
        hyperparams = {
            "shape": "Shape of synaptic weight value matrix; number inputs x number outputs",
            "batch_size": "Batch size dimension of this component",
            "weight_init": "Initialization conditions for synaptic weight (W) values",
            "resist_scale": "Resistance level scaling factor (applied to output of transformation)",
            "p_conn": "Probability of a connection existing (otherwise, it is masked to zero)",
            "A_plus": "Strength of long-term potentiation (LTP)",
            "A_minus": "Strength of long-term depression (LTD)",
            "tau_plus": "Time constant for long-term potentiation (LTP)",
            "tau_minus": "Time constant for long-term depression (LTD)",
            "w_decay": "Synaptic decay factor (applied under the ascent-style update)",
            "eta": "Global learning rate initial condition",
            "tau_w": "Time constant for synaptic adjustment (if Euler-style change used)"
        }
        info = {cls.__name__: properties,
                "compartments": compartment_props,
                "dynamics": "outputs = [(W * Rscale) * inputs] ; "
                            "dW_{ij}/dt = A_plus * exp(-(t_i - t_j)/tau_plus) * s_i (t_i > t_j) "
                            "- A_minus * exp((t_i - t_j)/tau_minus) * s_j (t_i < t_j)",
                "hyperparameters": hyperparams}
        return info

    def __repr__(self):
        comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))]
        maxlen = max(len(c) for c in comps) + 5
        lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
        for c in comps:
            stats = tensorstats(getattr(self, c).value)
            if stats is not None:
                line = [f"{k}: {v}" for k, v in stats.items()]
                line = ", ".join(line)
            else:
                line = "None"
            lines += f"  {f'({c})'.ljust(maxlen)}{line}\n"
        return lines
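
For reference, the pairing rule that _compute_update implements can be stated compactly. This is a restatement of the code above, not text from the commit: write Δt = t_i - t_j for the post-minus-pre time-of-last-spike difference and s_i, s_j for the post-/pre-synaptic spike indicators; the result is then scaled by the global rate eta and clipped to (0, w_bound) in _evolve.

\Delta W_{ij} =
\begin{cases}
  A_{+} \, e^{-\Delta t / \tau_{+}} \, s_i, & \Delta t > 0 \quad \text{(LTP: pre fired before post)} \\
  -A_{-} \, e^{\Delta t / \tau_{-}} \, s_j, & \Delta t < 0 \quad \text{(LTD: post fired before pre)}
\end{cases}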
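A minimal standalone sketch of the same math in plain JAX follows (toy values; it bypasses ngclearn's Compartment/resolver machinery, broadcasts t_delta directly rather than via (weights * 0 + 1.), and every name in it is illustrative rather than part of the library):

from jax import numpy as jnp

A_plus, A_minus = 1.0, 0.5
tau_plus, tau_minus = 10.0, 10.0

pre_tols = jnp.asarray([[3., 0., 8.]])   # (1, 3) pre-synaptic times-of-last-spike; 0 => never spiked
post_tols = jnp.asarray([[6., 5.]])      # (1, 2) post-synaptic times-of-last-spike
preSpike = jnp.asarray([[0., 0., 1.]])   # spike indicators at the current step
postSpike = jnp.asarray([[1., 0.]])

t_delta = post_tols - pre_tols.T                           # (3, 2) matrix of t_post - t_pre
t_delta = t_delta * (post_tols > 0.) * (pre_tols > 0.).T   # ignore never-spiked pairings
ltp = (A_plus * jnp.exp(-t_delta / tau_plus)) * (t_delta > 0.) * postSpike
ltd = -(A_minus * jnp.exp(t_delta / tau_minus)) * (t_delta < 0.) * preSpike.T
dW = ltp + ltd   # same shape as the (3, 2) weight matrix
print(dW)        # dW[0, 0] ≈ +0.74 (causal pair); dW[2, :] < 0 (acausal pairs); rest 0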

ngclearn/components/synapses/hebbian/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -1,4 +1,5 @@
 from .hebbianSynapse import HebbianSynapse
+from .STDPSynapse import STDPSynapse
 from .traceSTDPSynapse import TraceSTDPSynapse
 from .expSTDPSynapse import ExpSTDPSynapse
 from .eventSTDPSynapse import EventSTDPSynapse
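
With the one-line hooks added to the three __init__.py files in this commit, the new class resolves at each package level, so any of the following imports should work (a quick smoke test, not part of the commit):

from ngclearn.components import STDPSynapse
from ngclearn.components.synapses import STDPSynapse
from ngclearn.components.synapses.hebbian import STDPSynapse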
