Commit 5f7c33b

added gated-hebb syn
1 parent ee4e401 commit 5f7c33b

4 files changed: +146 -0 lines changed

ngclearn/components/__init__.py (+1)

@@ -24,6 +24,7 @@
 from .synapses.denseSynapse import DenseSynapse
 from .synapses.staticSynapse import StaticSynapse
 from .synapses.hebbian.hebbianSynapse import HebbianSynapse
+from .synapses.hebbian.gatedHebbianSynapse import GatedHebbianSynapse
 from .synapses.hebbian.traceSTDPSynapse import TraceSTDPSynapse
 from .synapses.hebbian.expSTDPSynapse import ExpSTDPSynapse
 from .synapses.hebbian.eventSTDPSynapse import EventSTDPSynapse

ngclearn/components/synapses/__init__.py (+1)

@@ -4,6 +4,7 @@
 from .STPDenseSynapse import STPDenseSynapse
 ## dense synaptic components
 from .hebbian.hebbianSynapse import HebbianSynapse
+from .hebbian.gatedHebbianSynapse import GatedHebbianSynapse
 from .hebbian.traceSTDPSynapse import TraceSTDPSynapse
 from .hebbian.expSTDPSynapse import ExpSTDPSynapse
 from .hebbian.eventSTDPSynapse import EventSTDPSynapse

ngclearn/components/synapses/hebbian/__init__.py (+1)

@@ -1,4 +1,5 @@
 from .hebbianSynapse import HebbianSynapse
+from .gatedHebbianSynapse import GatedHebbianSynapse
 from .traceSTDPSynapse import TraceSTDPSynapse
 from .expSTDPSynapse import ExpSTDPSynapse
 from .eventSTDPSynapse import EventSTDPSynapse
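
With these re-exports in place, the new component is reachable from any of the three package levels touched above (a minimal import sketch; it assumes a build of ngclearn that includes this commit):

    ## any one of these imports resolves to the same class after this commit
    from ngclearn.components import GatedHebbianSynapse
    from ngclearn.components.synapses import GatedHebbianSynapse
    from ngclearn.components.synapses.hebbian import GatedHebbianSynapse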
ngclearn/components/synapses/hebbian/gatedHebbianSynapse.py (new file, +143)

@@ -0,0 +1,143 @@
+from jax import random, numpy as jnp, jit
+from ngclearn import resolver, Component, Compartment
+from ngclearn.components.synapses import DenseSynapse
+from ngclearn.utils import tensorstats
+
+class GatedHebbianSynapse(DenseSynapse):
+
+    # Define Functions
+    def __init__(self, name, shape, eta=0., weight_init=None, bias_init=None,
+                 w_bound=1., w_decay=0., p_conn=1., resist_scale=1.,
+                 batch_size=1, **kwargs):
+        super().__init__(name, shape, weight_init, bias_init, resist_scale,
+                         p_conn, batch_size=batch_size, **kwargs)
+
+        ## synaptic plasticity properties and characteristics
+        self.shape = shape
+        self.w_bound = w_bound
+        self.w_decay = w_decay ## synaptic decay
+        self.eta = eta
+
+        ## compartments (state of the cell, parameters, will be updated through stateless calls)
+        self.preVals = jnp.zeros((self.batch_size, shape[0]))
+        self.postVals = jnp.zeros((self.batch_size, shape[1]))
+        self.pre = Compartment(self.preVals)
+        self.post = Compartment(self.postVals)
+        self.preSpike = Compartment(self.preVals)
+        self.postSpike = Compartment(self.postVals)
+        self.dWeights = Compartment(jnp.zeros(shape))
+        self.dBiases = Compartment(jnp.zeros(shape[1]))
+
+    @staticmethod
+    def _compute_update(w_bound, pre, post, weights):
+        ## calculate synaptic update values
+        dW = jnp.matmul(pre.T, post)
+        db = jnp.sum(post, axis=0, keepdims=True)
+        # if w_bound > 0.: ## (disabled) soft-bound scaling of the update
+        #     dW = dW * (w_bound - jnp.abs(weights))
+        return dW, db
+
+    @staticmethod
+    def _evolve(bias_init, eta, w_decay, w_bound, pre, post, weights, biases):
+        ## calculate synaptic update values
+        dWeights, dBiases = GatedHebbianSynapse._compute_update(w_bound, pre, post, weights)
+        weights = weights + dWeights * eta
+        if bias_init is not None:
+            biases = biases + dBiases * eta
+        if w_decay > 0.:
+            Wdec = jnp.matmul((1. - pre).T, post) * w_decay ## decay gated by absent pre-activity
+            weights = weights - Wdec
+        weights = jnp.clip(weights, 0., w_bound)
+        return weights, biases, dWeights, dBiases
+
+    @resolver(_evolve)
+    def evolve(self, weights, biases, dWeights, dBiases):
+        self.weights.set(weights)
+        self.biases.set(biases)
+        self.dWeights.set(dWeights)
+        self.dBiases.set(dBiases)
+
+    @staticmethod
+    def _reset(batch_size, shape):
+        preVals = jnp.zeros((batch_size, shape[0]))
+        postVals = jnp.zeros((batch_size, shape[1]))
+        return (
+            preVals, # inputs
+            postVals, # outputs
+            preVals, # pre
+            postVals, # post
+            preVals, # preSpike
+            postVals, # postSpike
+            jnp.zeros(shape), # dWeights
+            jnp.zeros(shape[1]), # dBiases
+        )
+
+    @resolver(_reset)
+    def reset(self, inputs, outputs, pre, post, preSpike, postSpike, dWeights, dBiases):
+        self.inputs.set(inputs)
+        self.outputs.set(outputs)
+        self.pre.set(pre)
+        self.post.set(post)
+        self.preSpike.set(preSpike)
+        self.postSpike.set(postSpike)
+        self.dWeights.set(dWeights)
+        self.dBiases.set(dBiases)
+
+    @classmethod
+    def help(cls): ## component help function
+        properties = {
+            "synapse_type": "GatedHebbianSynapse - performs an adaptable synaptic "
+                            "transformation of inputs to produce output signals; "
+                            "synapses are adjusted via a two-term/factor Hebbian adjustment"
+        }
+        compartment_props = {
+            "inputs":
+                {"inputs": "Takes in external input signal values",
+                 "pre": "Pre-synaptic statistic for Hebb rule (z_j)",
+                 "post": "Post-synaptic statistic for Hebb rule (z_i)"},
+            "states":
+                {"weights": "Synapse efficacy/strength parameter values",
+                 "biases": "Base-rate/bias parameter values",
+                 "key": "JAX PRNG key"},
+            "analytics":
+                {"dWeights": "Synaptic weight value adjustment matrix produced at time t",
+                 "dBiases": "Synaptic bias/base-rate value adjustment vector produced at time t"},
+            "outputs":
+                {"outputs": "Output of synaptic transformation"},
+        }
+        hyperparams = {
+            "shape": "Shape of synaptic weight value matrix; number inputs x number outputs",
+            "batch_size": "Batch size dimension of this component",
+            "weight_init": "Initialization conditions for synaptic weight (W) values",
+            "bias_init": "Initialization conditions for bias/base-rate (b) values",
+            "resist_scale": "Resistance level scaling factor (applied to output of transformation)",
+            "p_conn": "Probability of a connection existing (otherwise, it is masked to zero)",
+            "is_nonnegative": "Should synapses be constrained to be non-negative post-updates?",
+            "sign_value": "Scalar `flipping` constant -- changes direction to Hebbian descent if < 0",
+            "eta": "Global (fixed) learning rate",
+            "pre_wght": "Pre-synaptic weighting coefficient (q_pre)",
+            "post_wght": "Post-synaptic weighting coefficient (q_post)",
+            "w_bound": "Soft synaptic bound applied to synapses post-update",
+            "w_decay": "Synaptic decay term",
+            "optim_type": "Choice of optimizer to adjust synaptic weights"
+        }
+        info = {cls.__name__: properties,
+                "compartments": compartment_props,
+                "dynamics": "outputs = [(W * Rscale) * inputs] + b ; "
+                            "dW_{ij}/dt = eta * [(z_j * q_pre) * (z_i * q_post)] - W_{ij} * w_decay",
+                "hyperparameters": hyperparams}
+        return info
+
+    def __repr__(self):
+        comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))]
+        maxlen = max(len(c) for c in comps) + 5
+        lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
+        for c in comps:
+            stats = tensorstats(getattr(self, c).value)
+            if stats is not None:
+                line = [f"{k}: {v}" for k, v in stats.items()]
+                line = ", ".join(line)
+            else:
+                line = "None"
+            lines += f" {f'({c})'.ljust(maxlen)}{line}\n"
+        return lines
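
For orientation, the learning step carried out by `_evolve` above amounts to a two-factor Hebbian update with a pre-synaptically gated decay: dW = pre^T post, then W <- clip(W + eta * dW - w_decay * (1 - pre)^T post, 0, w_bound). Below is a minimal plain-JAX sketch of that arithmetic, using made-up toy shapes and rates; it mirrors the committed code but is not part of the commit:

    from jax import numpy as jnp, random

    key = random.PRNGKey(0)                      ## hypothetical seed
    k1, k2 = random.split(key)
    n_in, n_out, batch = 4, 3, 2                 ## hypothetical toy dimensions
    eta, w_decay, w_bound = 0.01, 0.001, 1.      ## hypothetical rates/bounds

    pre = random.uniform(k1, (batch, n_in))      ## stand-in pre-synaptic statistics
    post = random.uniform(k2, (batch, n_out))    ## stand-in post-synaptic statistics
    W = jnp.zeros((n_in, n_out))

    dW = jnp.matmul(pre.T, post)                 ## two-factor Hebbian product (_compute_update)
    db = jnp.sum(post, axis=0, keepdims=True)    ## bias update from post-synaptic activity
    W = W + dW * eta                             ## Hebbian step (_evolve)
    W = W - jnp.matmul((1. - pre).T, post) * w_decay  ## decay gated by absent pre-activity
    W = jnp.clip(W, 0., w_bound)                 ## hard non-negative bound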
