
Commit cf53968

feat NGC module regression (#86)
* feat NGC module regression
* Update __init__.py
* Update __init__.py
* Update elastic_net.py
* Update lasso.py
* Update ridge.py
* Update elastic_net.py
* Update ridge.py
* Update lasso.py
1 parent eeb057a commit cf53968

File tree

ngclearn/modules/__init__.py
ngclearn/modules/regression/__init__.py
ngclearn/modules/regression/elastic_net.py
ngclearn/modules/regression/lasso.py
ngclearn/modules/regression/ridge.py

5 files changed: +480 -0 lines changed

ngclearn/modules/__init__.py

Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
from .regression.elastic_net import Iterative_ElasticNet
from .regression.lasso import Iterative_Lasso
from .regression.ridge import Iterative_Ridge
ngclearn/modules/regression/__init__.py

Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
from .elastic_net import Iterative_ElasticNet
from .lasso import Iterative_Lasso
from .ridge import Iterative_Ridge
ngclearn/modules/regression/elastic_net.py

Lines changed: 153 additions & 0 deletions
@@ -0,0 +1,153 @@
from jax import random, jit
import numpy as np
from ngclearn.utils import weight_distribution as dist
from ngclearn import Context, numpy as jnp
from ngclearn.components import (RateCell,
                                 HebbianSynapse,
                                 GaussianErrorCell,
                                 StaticSynapse)
from ngclearn.utils.model_utils import scanner


class Iterative_ElasticNet():
    """
    A neural circuit implementation of the iterative Elastic Net (L1 and L2)
    algorithm using a Hebbian learning update rule.

    The circuit implements sparse regression through Hebbian synapses with
    Elastic Net regularization.

    The specific update that characterizes this model adjusts W by adding
    lmbda * dW_reg to dW (the gradient of the loss/energy function):

    | dW_reg = (jnp.sign(W) * l1_ratio) + (W * (1 - l1_ratio) / 2)
    | dW/dt = dW + lmbda * dW_reg

    | --- Circuit Components: ---
    | W - HebbianSynapse for learning regularized dictionary weights
    | err - GaussianErrorCell for computing prediction errors
    | --- Component Compartments: ---
    | W.inputs - input features (takes in external signals)
    | W.pre - pre-synaptic activity for Hebbian learning
    | W.post - post-synaptic error signals
    | W.weights - learned dictionary coefficients
    | err.mu - predicted outputs
    | err.target - target signals (target vector)
    | err.dmu - error gradients
    | err.L - loss/energy values

    Args:
        key: JAX PRNG key for random number generation

        name: string name for this solver

        sys_dim: dimensionality of the system/target space

        dict_dim: dimensionality of the dictionary/feature space (the number
            of predictors)

        batch_size: number of samples to process in parallel

        weight_fill: initial constant value to fill the weight matrix with (Default: 0.05)

        lr: learning rate for synaptic weight updates (Default: 0.01)

        lmbda: Elastic Net regularization lambda parameter (Default: 0.0001)

        l1_ratio: mixing ratio between the L1 and L2 penalties (Default: 0.5)

        optim_type: optimization type for updating weights; supported values are
            "sgd" and "adam" (Default: "adam")

        threshold: minimum absolute coefficient value - values below this are set
            to zero during thresholding (Default: 0.05)

        epochs: number of training epochs (Default: 100)
    """
    def __init__(self, key, name, sys_dim, dict_dim, batch_size, weight_fill=0.05,
                 lr=0.01, lmbda=0.0001, l1_ratio=0.5, optim_type="adam",
                 threshold=0.05, epochs=100):
        key, *subkeys = random.split(key, 10)

        ## synaptic plasticity properties and characteristics
        self.T = 100
        self.dt = 1
        self.epochs = epochs
        self.weight_fill = weight_fill
        self.threshold = threshold
        self.name = name
        feature_dim = dict_dim

        with Context(self.name) as self.circuit:
            self.W = HebbianSynapse("W", shape=(feature_dim, sys_dim), eta=lr,
                                    sign_value=-1, weight_init=dist.constant(weight_fill),
                                    prior=('elastic_net', (lmbda, l1_ratio)),
                                    optim_type=optim_type, key=subkeys[0])
            self.err = GaussianErrorCell("err", n_units=sys_dim)

            # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
            self.W.batch_size = batch_size
            self.err.batch_size = batch_size
            # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
            self.err.mu << self.W.outputs
            self.W.post << self.err.dmu
            # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
            advance_cmd, advance_args = self.circuit.compile_by_key(
                self.W,    ## execute prediction synapses
                self.err,  ## finally, execute error neurons
                compile_key="advance_state")
            evolve_cmd, evolve_args = self.circuit.compile_by_key(
                self.W, compile_key="evolve")
            reset_cmd, reset_args = self.circuit.compile_by_key(
                self.err, self.W, compile_key="reset")
            # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
            self.dynamic()

    def dynamic(self):  ## create dynamic commands for self.circuit
        W, err = self.circuit.get_components("W", "err")
        self.W = W
        self.err = err

        @Context.dynamicCommand
        def batch_set(batch_size):
            self.W.batch_size = batch_size
            self.err.batch_size = batch_size

        @Context.dynamicCommand
        def clamps(y_scaled, X):
            self.W.inputs.set(X)
            self.W.pre.set(X)
            self.err.target.set(y_scaled)

        self.circuit.wrap_and_add_command(jit(self.circuit.evolve), name="evolve")
        self.circuit.wrap_and_add_command(jit(self.circuit.advance_state), name="advance")
        self.circuit.wrap_and_add_command(jit(self.circuit.reset), name="reset")

        @scanner
        def _process(compartment_values, args):
            _t, _dt = args
            compartment_values = self.circuit.advance_state(compartment_values, t=_t, dt=_dt)
            return compartment_values, compartment_values[self.W.weights.path]

    def thresholding(self, scale=1.):
        coef_old = self.coef_
        new_coeff = jnp.where(jnp.abs(coef_old) >= self.threshold, coef_old, 0.)

        self.coef_ = new_coeff * scale
        self.W.weights.set(new_coeff)

        return self.coef_, coef_old

    def fit(self, y, X):
        self.circuit.reset()
        self.circuit.clamps(y_scaled=y, X=X)

        for _ in range(self.epochs):
            self.circuit._process(jnp.array([[self.dt * t, self.dt] for t in range(self.T)]))
            self.circuit.evolve(t=self.T, dt=self.dt)

        self.coef_ = np.array(self.W.weights.value)

        return self.coef_, self.err.mu.value, self.err.L.value
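For orientation, a minimal usage sketch of the solver above (not part of the commit). It assumes ngclearn and JAX are installed and that the class is re-exported through ngclearn/modules/__init__.py as shown earlier; the toy data, dimensions, and variable names are hypothetical:

import numpy as np
from jax import random
from ngclearn.modules import Iterative_ElasticNet

key = random.PRNGKey(1234)
X = np.random.randn(20, 5)                           ## 20 samples x 5 predictors
w_true = np.array([[1.5], [0.], [0.], [-2.], [0.]])  ## sparse ground-truth coefficients
y = X @ w_true                                       ## 20 samples x 1 target

## the prior adds lmbda * (sign(W) * l1_ratio + W * (1 - l1_ratio) / 2) to dW
model = Iterative_ElasticNet(key, "enet", sys_dim=1, dict_dim=5, batch_size=20,
                             lmbda=0.0001, l1_ratio=0.5, epochs=100)
coef, mu, L = model.fit(y, X)           ## targets first, then features
coef, coef_old = model.thresholding()   ## zero out coefficients with |w| < threshold

Note that fit() takes the target first and the design matrix second, mirroring the clamps(y_scaled, X) command inside the circuit.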

ngclearn/modules/regression/lasso.py

Lines changed: 157 additions & 0 deletions
@@ -0,0 +1,157 @@
from jax import random, jit
import numpy as np
from ngclearn.utils import weight_distribution as dist
from ngclearn import Context, numpy as jnp
from ngclearn.components import (HebbianSynapse,
                                 GaussianErrorCell)
from ngclearn.utils.model_utils import scanner


class Iterative_Lasso():
    """
    A neural circuit implementation of the iterative Lasso (L1) algorithm
    using a Hebbian learning update rule.

    The circuit implements sparse coding through Hebbian synapses with L1
    regularization.

    The specific update that characterizes this model adds lmbda * sign(W)
    to dW (the gradient of the loss/energy function):

    | dW/dt = dW + lmbda * sign(W)

    | --- Circuit Components: ---
    | W - HebbianSynapse for learning sparse dictionary weights
    | err - GaussianErrorCell for computing prediction errors
    | --- Component Compartments: ---
    | W.inputs - input features (takes in external signals)
    | W.pre - pre-synaptic activity for Hebbian learning
    | W.post - post-synaptic error signals
    | W.weights - learned dictionary coefficients
    | err.mu - predicted outputs
    | err.target - target signals (target vector)
    | err.dmu - error gradients
    | err.L - loss/energy values

    Args:
        key: JAX PRNG key for random number generation

        name: string name for this solver

        sys_dim: dimensionality of the system/target space

        dict_dim: dimensionality of the dictionary/feature space (the number
            of predictors)

        batch_size: number of samples to process in parallel

        weight_fill: initial constant value to fill the weight matrix with (Default: 0.05)

        lr: learning rate for synaptic weight updates (Default: 0.01)

        lasso_lmbda: L1 regularization lambda parameter (Default: 0.0001)

        optim_type: optimization type for updating weights; supported values are
            "sgd" and "adam" (Default: "adam")

        threshold: minimum absolute coefficient value - values below this are set
            to zero during thresholding (Default: 0.001)

        epochs: number of training epochs (Default: 100)
    """

    # Define Functions
    def __init__(self, key, name, sys_dim, dict_dim, batch_size, weight_fill=0.05,
                 lr=0.01, lasso_lmbda=0.0001, optim_type="adam", threshold=0.001,
                 epochs=100):
        key, *subkeys = random.split(key, 10)

        self.T = 100
        self.dt = 1
        self.epochs = epochs
        self.weight_fill = weight_fill
        self.threshold = threshold
        self.name = name
        feature_dim = dict_dim

        with Context(self.name) as self.circuit:
            self.W = HebbianSynapse("W", shape=(feature_dim, sys_dim), eta=lr,
                                    sign_value=-1, weight_init=dist.constant(weight_fill),
                                    prior=('lasso', lasso_lmbda),
                                    optim_type=optim_type, key=subkeys[0])
            self.err = GaussianErrorCell("err", n_units=sys_dim)
            # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
            self.W.batch_size = batch_size
            self.err.batch_size = batch_size
            # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
            self.err.mu << self.W.outputs
            self.W.post << self.err.dmu
            # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
            advance_cmd, advance_args = self.circuit.compile_by_key(
                self.W,    ## execute prediction synapses
                self.err,  ## finally, execute error neurons
                compile_key="advance_state")
            evolve_cmd, evolve_args = self.circuit.compile_by_key(
                self.W, compile_key="evolve")
            reset_cmd, reset_args = self.circuit.compile_by_key(
                self.err, self.W, compile_key="reset")
            # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
            self.dynamic()

    def dynamic(self):  ## create dynamic commands for self.circuit
        W, err = self.circuit.get_components("W", "err")
        self.W = W
        self.err = err

        @Context.dynamicCommand
        def batch_set(batch_size):
            self.W.batch_size = batch_size
            self.err.batch_size = batch_size

        @Context.dynamicCommand
        def clamps(y_scaled, X):
            self.W.inputs.set(X)
            self.W.pre.set(X)
            self.err.target.set(y_scaled)

        self.circuit.wrap_and_add_command(jit(self.circuit.evolve), name="evolve")
        self.circuit.wrap_and_add_command(jit(self.circuit.advance_state), name="advance")
        self.circuit.wrap_and_add_command(jit(self.circuit.reset), name="reset")

        @scanner
        def _process(compartment_values, args):
            _t, _dt = args
            compartment_values = self.circuit.advance_state(compartment_values, t=_t, dt=_dt)
            return compartment_values, compartment_values[self.W.weights.path]

    def thresholding(self, scale=2):
        coef_old = self.coef_
        new_coeff = jnp.where(jnp.abs(coef_old) >= self.threshold, coef_old, 0.)

        self.coef_ = new_coeff * scale
        self.W.weights.set(new_coeff)

        return self.coef_, coef_old

    def fit(self, y, X):
        self.circuit.reset()
        self.circuit.clamps(y_scaled=y, X=X)

        for _ in range(self.epochs):
            self.circuit._process(jnp.array([[self.dt * t, self.dt] for t in range(self.T)]))
            self.circuit.evolve(t=self.T, dt=self.dt)

        self.coef_ = np.array(self.W.weights.value)

        return self.coef_, self.err.mu.value, self.err.L.value