|
| 1 | +from jax import random, numpy as jnp, jit |
| 2 | +from ngclearn import resolver, Component, Compartment |
| 3 | +from ngclearn.components.synapses import DenseSynapse |
| 4 | +from ngclearn.utils import tensorstats |
| 5 | + |
| 6 | +class STDPSynapse(DenseSynapse): # power-law / trace-based STDP |
| 7 | + """ |
| 8 | + A synaptic cable that adjusts its efficacies via raw |
| 9 | + spike-timing-dependent plasticity (STDP). |
| 10 | +
|
| 11 | + | --- Synapse Compartments: --- |
| 12 | + | inputs - input (takes in external signals) |
| 13 | + | outputs - output signals (transformation induced by synapses) |
| 14 | + | weights - current value matrix of synaptic efficacies |
| 15 | + | key - JAX PRNG key |
| 16 | + | --- Synaptic Plasticity Compartments: --- |
| 17 | + | preSpike - pre-synaptic spike to drive long-term potentiation (takes in external signals) |
| 18 | + | postSpike - post-synaptic spike to drive long-term depression (takes in external signals) |
| 19 | + | pre_tols - pre-synaptic time-of-last-spike (takes in external signals) |
| 20 | + | post_tols - post-synaptic time-of-last-spike (takes in external signals) |
| 21 | + | dWeights - current delta matrix containing changes to be applied to synaptic efficacies |
| 22 | + | eta - global learning rate (multiplier beyond A_plus and A_minus) |
| 23 | +
|
| 24 | + | References: |
| 25 | + | Markram, Henry, et al. "Regulation of synaptic efficacy by coincidence of |
| 26 | + | postsynaptic APs and EPSPs." Science 275.5297 (1997): 213-215. |
| 27 | + | |
| 28 | + | Bi, Guo-qiang, and Mu-ming Poo. "Synaptic modification by correlated |
| 29 | + | activity: Hebb's postulate revisited." Annual review of neuroscience 24.1 |
| 30 | + | (2001): 139-166. |
| 31 | +
|
| 32 | + Args: |
| 33 | + name: the string name of this cell |
| 34 | +
|
| 35 | + shape: tuple specifying shape of this synaptic cable (usually a 2-tuple |
| 36 | + with number of inputs by number of outputs) |
| 37 | +
|
| 38 | + A_plus: strength of long-term potentiation (LTP) |
| 39 | +
|
| 40 | + A_minus: strength of long-term depression (LTD) |
| 41 | +
|
| 42 | + tau_plus: time constant of long-term potentiation (LTP) |
| 43 | +
|
| 44 | + tau_minus: time constant of long-term depression (LTD) |
| 45 | +
|
| 46 | + eta: global learning rate initial value/condition (default: 1) |
| 47 | +
|
| 48 | + tau_w: time constant for synaptic adjustment; setting this to zero |
| 49 | + disables Euler-style synaptic adjustment (default: 0) |
| 50 | +
|
| 51 | + weight_init: a kernel to drive initialization of this synaptic cable's values; |
| 52 | + typically a tuple with 1st element as a string calling the name of |
| 53 | + initialization to use |
| 54 | +
|
| 55 | + resist_scale: a fixed scaling factor to apply to synaptic transform |
| 56 | + (Default: 1.), i.e., yields: out = ((W * Rscale) * in) |
| 57 | +
|
| 58 | + p_conn: probability of a connection existing (default: 1); setting |
| 59 | + this to < 1. will result in a sparser synaptic structure |
| 60 | +
|
| 61 | + w_bound: maximum value/magnitude any synaptic efficacy can be (default: 1) |
| 62 | + """ |
| 63 | + |
| 64 | + # Define Functions |
| 65 | + def __init__(self, name, shape, A_plus, A_minus, tau_plus=10., tau_minus=10., w_decay=0., |
| 66 | + eta=1., tau_w=0., weight_init=None, resist_scale=1., p_conn=1., w_bound=1., |
| 67 | + batch_size=1, **kwargs): |
| 68 | + super().__init__(name, shape, weight_init, None, resist_scale, |
| 69 | + p_conn, batch_size=batch_size, **kwargs) |
| 70 | + assert self.batch_size == 1 ## note: STDP only supports online learning in this implementation |
| 71 | + ## Synaptic hyper-parameters |
| 72 | + self.shape = shape ## shape of synaptic efficacy matrix |
| 73 | + self.Aplus = A_plus ## LTP strength |
| 74 | + self.Aminus = A_minus ## LTD strength |
| 75 | + self.tau_plus = tau_plus ## LTP time constant |
| 76 | + self.tau_minus = tau_minus ## LTD time constant |
| 77 | + self.Rscale = resist_scale ## post-transformation scale factor |
| 78 | + self.w_bound = w_bound #1. ## soft weight constraint |
| 79 | + self.tau_w = tau_w ## synaptic update time constant |
| 80 | + self.w_decay = w_decay |
| 81 | + |
| 82 | + ## Compartment setup |
| 83 | + preVals = jnp.zeros((self.batch_size, shape[0])) |
| 84 | + postVals = jnp.zeros((self.batch_size, shape[1])) |
| 85 | + self.preSpike = Compartment(preVals) |
| 86 | + self.postSpike = Compartment(postVals) |
| 87 | + self.pre_tols = Compartment(preVals) ## pre-synaptic time-of-last-spike |
| 88 | + self.post_tols = Compartment(postVals) ## post-synaptic time-of-last-spike |
| 89 | + self.dWeights = Compartment(self.weights.value * 0) |
| 90 | + self.eta = Compartment(jnp.ones((1, 1)) * eta) ## global learning rate |
| 91 | + |
| 92 | + @staticmethod |
| 93 | + def _compute_update(Aplus, Aminus, tau_plus, tau_minus, preSpike, postSpike, |
| 94 | + pre_tols, post_tols, weights): |
| 95 | + ## calculate time deltas matrix block --> (t_post - t_pre) |
| 96 | + post_m = (post_tols > 0.) ## zero post-tols mask |
| 97 | + pre_m = (pre_tols > 0.).T ## zero pre-tols mask |
| 98 | + t_delta = ((weights * 0 + 1.) * post_tols) - pre_tols.T ## t_delta.shape = weights.shape |
| 99 | + t_delta = t_delta * post_m * pre_m ## mask out zero tols and same-time spikes |
| 100 | + pos_t_delta_m = (t_delta > 0.) ## positive t-delta mask |
| 101 | + neg_t_delta_m = (t_delta < 0.) ## negative t-delta mask |
| 102 | + #t_delta = t_delta * pos_t_delta_m + t_delta * neg_t_delta_m ## mask out same time spikes |
| 103 | + ## calculate post-synaptic term |
| 104 | + postTerm = jnp.exp(-t_delta/tau_plus) * pos_t_delta_m |
| 105 | + dWpost = postTerm * (postSpike * Aplus) |
| 106 | + dWpre = 0. |
| 107 | + if Aminus > 0.: |
| 108 | + ## calculate pre-synaptic term |
| 109 | + preTerm = jnp.exp(-t_delta / tau_minus) * neg_t_delta_m |
| 110 | + dWpre = -preTerm * (preSpike.T * Aminus) |
| 111 | + ## calc final weighted adjustment |
| 112 | + dW = (dWpost + dWpre) |
| 113 | + return dW |
| 114 | + |
| 115 | + @staticmethod |
| 116 | + def _evolve(dt, w_bound, w_decay, tau_w, Aplus, Aminus, tau_plus, tau_minus, preSpike, |
| 117 | + postSpike, pre_tols, post_tols, weights, eta): |
| 118 | + dWeights = STDPSynapse._compute_update( |
| 119 | + Aplus, Aminus, tau_plus, tau_minus, preSpike, postSpike, pre_tols, |
| 120 | + post_tols, weights |
| 121 | + ) |
| 122 | + ## shift/alter values of synaptic efficacies |
| 123 | + if tau_w > 0.: ## triggers Euler-style synaptic update |
| 124 | + weights = weights + (-weights * dt/tau_w + dWeights * eta) |
| 125 | + else: ## raw simple ascent-style update |
| 126 | + weights = weights + dWeights * eta - weights * w_decay |
| 127 | + ## enforce non-negativity |
| 128 | + eps = 0.001 # 0.01 |
| 129 | + weights = jnp.clip(weights, eps, w_bound - eps) # jnp.abs(w_bound)) |
| 130 | + return weights, dWeights |
| 131 | + |
| 132 | + @resolver(_evolve) |
| 133 | + def evolve(self, weights, dWeights): |
| 134 | + self.weights.set(weights) |
| 135 | + self.dWeights.set(dWeights) |
| 136 | + |
| 137 | + @staticmethod |
| 138 | + def _reset(batch_size, shape): |
| 139 | + preVals = jnp.zeros((batch_size, shape[0])) |
| 140 | + postVals = jnp.zeros((batch_size, shape[1])) |
| 141 | + inputs = preVals |
| 142 | + outputs = postVals |
| 143 | + preSpike = preVals |
| 144 | + postSpike = postVals |
| 145 | + pre_tols = preVals |
| 146 | + post_tols = postVals |
| 147 | + dWeights = jnp.zeros(shape) |
| 148 | + return inputs, outputs, preSpike, postSpike, pre_tols, post_tols, dWeights |
| 149 | + |
| 150 | + @resolver(_reset) |
| 151 | + def reset(self, inputs, outputs, preSpike, postSpike, pre_tols, post_tols, dWeights): |
| 152 | + self.inputs.set(inputs) |
| 153 | + self.outputs.set(outputs) |
| 154 | + self.preSpike.set(preSpike) |
| 155 | + self.postSpike.set(postSpike) |
| 156 | + self.pre_tols.set(pre_tols) |
| 157 | + self.post_tols.set(post_tols) |
| 158 | + self.dWeights.set(dWeights) |
| 159 | + |
| 160 | + @classmethod |
| 161 | + def help(cls): ## component help function |
| 162 | + properties = { |
| 163 | + "synapse_type": "STDPSynapse - performs an adaptable synaptic " |
| 164 | + "transformation of inputs to produce output signals; " |
| 165 | + "synapses are adjusted with classical " |
| 166 | + "spike-timing-dependent plasticity (STDP)" |
| 167 | + } |
| 168 | + compartment_props = { |
| 169 | + "inputs": |
| 170 | + {"inputs": "Takes in external input signal values", |
| 171 | + "preSpike": "Pre-synaptic spike compartment event for STDP (s_j)", |
| 172 | + "postSpike": "Post-synaptic spike compartment event for STDP (s_i)", |
| 173 | + "pre_tols": "Pre-synaptic time-of-last-spike (t_j)", |
| 174 | + "post_tols": "Post-synaptic time-of-last-spike (t_i)"}, |
| 175 | + "states": |
| 176 | + {"weights": "Synapse efficacy/strength parameter values", |
| 177 | + "biases": "Base-rate/bias parameter values", |
| 178 | + "eta": "Global learning rate (multiplier beyond A_plus and A_minus)", |
| 179 | + "key": "JAX PRNG key"}, |
| 180 | + "analytics": |
| 181 | + {"dWeights": "Synaptic weight value adjustment matrix produced at time t"}, |
| 182 | + "outputs": |
| 183 | + {"outputs": "Output of synaptic transformation"}, |
| 184 | + } |
| 185 | + hyperparams = { |
| 186 | + "shape": "Shape of synaptic weight value matrix; number inputs x number outputs", |
| 187 | + "batch_size": "Batch size dimension of this component", |
| 188 | + "weight_init": "Initialization conditions for synaptic weight (W) values", |
| 189 | + "resist_scale": "Resistance level scaling factor (applied to output of transformation)", |
| 190 | + "p_conn": "Probability of a connection existing (otherwise, it is masked to zero)", |
| 191 | + "A_plus": "Strength of long-term potentiation (LTP)", |
| 192 | + "A_minus": "Strength of long-term depression (LTD)", |
| 193 | + "tau_plus": "Time constant for long-term potentiation (LTP)", |
| 194 | + "tau_minus": "Time constant for long-term depression (LTD)", |
| 195 | + "eta": "Global learning rate initial condition", |
| 196 | + "tau_w": "Time constant for synaptic adjustment (if Euler-style change used)" |
| 197 | + } |
| 198 | + info = {cls.__name__: properties, |
| 199 | + "compartments": compartment_props, |
| 200 | + "dynamics": "outputs = [(W * Rscale) * inputs] ;" |
| 201 | + "dW_{ij}/dt = A_plus * exp(-(t_i - t_j)/tau_plus) * s_j -" |
| 202 | + " A_minus exp(-(t_i - t_j)/tau_minus) * s_i", |
| 203 | + "hyperparameters": hyperparams} |
| 204 | + return info |
| 205 | + |
| 206 | + def __repr__(self): |
| 207 | + comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))] |
| 208 | + maxlen = max(len(c) for c in comps) + 5 |
| 209 | + lines = f"[{self.__class__.__name__}] PATH: {self.name}\n" |
| 210 | + for c in comps: |
| 211 | + stats = tensorstats(getattr(self, c).value) |
| 212 | + if stats is not None: |
| 213 | + line = [f"{k}: {v}" for k, v in stats.items()] |
| 214 | + line = ", ".join(line) |
| 215 | + else: |
| 216 | + line = "None" |
| 217 | + lines += f" {f'({c})'.ljust(maxlen)}{line}\n" |
| 218 | + return lines |
0 commit comments