|
| 1 | +import jax |
| 2 | +import pandas as pd |
| 3 | +from jax import random, jit |
| 4 | +import numpy as np |
| 5 | +from scipy.integrate import solve_ivp |
| 6 | +import matplotlib.pyplot as plt |
| 7 | +from ngcsimlib.utils import Get_Compartment_Batch |
| 8 | +from ngclearn.utils.model_utils import normalize_matrix |
| 9 | +from ngclearn.utils import weight_distribution as dist |
| 10 | +from ngclearn import Context, numpy as jnp |
| 11 | +from ngclearn.components import (RateCell, |
| 12 | + HebbianSynapse, |
| 13 | + GaussianErrorCell, |
| 14 | + StaticSynapse) |
| 15 | +from ngclearn.utils.model_utils import scanner |
| 16 | + |
| 17 | + |
class Iterative_Lasso():
    """
    A neural circuit implementation of the iterative Lasso (L1) algorithm
    using a Hebbian learning update rule.

    The circuit implements sparse coding through Hebbian synapses with L1 regularization.

    The specific differential equation that characterizes this model adds lmbda * sign(W)
    to dW (the gradient of the loss/energy function):

    | dW/dt = dW + lmbda * sign(W)

    | --- Circuit Components: ---
    | W - HebbianSynapse for learning sparse dictionary weights
    | err - GaussianErrorCell for computing prediction errors
    | --- Component Compartments ---
    | W.inputs - input features (takes in external signals)
    | W.pre - pre-synaptic activity for Hebbian learning
    | W.post - post-synaptic error signals
    | W.weights - learned dictionary coefficients
    | err.mu - predicted outputs
    | err.target - target signals (target vector)
    | err.dmu - error gradients
    | err.L - loss/energy values

    Args:
        key: JAX PRNG key for random number generation

        name: string name for this solver (also used as the name of its Context)

        sys_dim: dimensionality of the system/target space

        dict_dim: dimensionality of the dictionary/feature space/the number of predictors

        batch_size: number of samples to process in parallel

        weight_fill: initial constant value to fill weight matrix with (Default: 0.05)

        lr: learning rate for synaptic weight updates (Default: 0.01)

        lasso_lmbda: L1 regularization lambda parameter (Default: 0.0001)

        optim_type: optimization type for updating weights; supported values are
            "sgd" and "adam" (Default: "adam")

        threshold: minimum absolute coefficient value - values below this are set
            to zero during thresholding (Default: 0.001)

        epochs: number of training epochs (Default: 100)
    """

    def __init__(self, key, name, sys_dim, dict_dim, batch_size, weight_fill=0.05, lr=0.01,
                 lasso_lmbda=0.0001, optim_type="adam", threshold=0.001, epochs=100):
        # Only the first sub-key is consumed (by the synapse); the split size is
        # kept at 10 so the derived key is the same as before.
        _, *subkeys = random.split(key, 10)

        self.T = 100  # number of advance_state steps scanned per epoch (see fit)
        self.dt = 1  # integration step handed to advance_state/evolve
        self.epochs = epochs
        self.weight_fill = weight_fill
        self.threshold = threshold
        self.name = name
        feature_dim = dict_dim

        with Context(self.name) as self.circuit:
            self.W = HebbianSynapse("W", shape=(feature_dim, sys_dim), eta=lr,
                                    sign_value=-1, weight_init=dist.constant(weight_fill),
                                    prior=('lasso', lasso_lmbda),
                                    optim_type=optim_type, key=subkeys[0])
            self.err = GaussianErrorCell("err", n_units=sys_dim)
            # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
            self.W.batch_size = batch_size
            self.err.batch_size = batch_size
            # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
            ## wiring: synaptic output is the prediction; the error cell's
            ## dmu is fed back as the post-synaptic signal
            self.err.mu << self.W.outputs
            self.W.post << self.err.dmu
            # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
            # The return values of compile_by_key are not used by this class.
            # NOTE(review): presumably these calls are what make
            # circuit.advance_state / evolve / reset available to dynamic()
            # below - confirm against ngcsimlib.
            self.circuit.compile_by_key(self.W,  ## execute prediction synapses
                                        self.err,  ## finally, execute error neurons
                                        compile_key="advance_state")
            self.circuit.compile_by_key(self.W, compile_key="evolve")
            self.circuit.compile_by_key(self.err, self.W, compile_key="reset")
            # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
            self.dynamic()

    def dynamic(self):  ## create dynamic commands for self.circuit
        """
        Registers the dynamic commands of this circuit: ``batch_set``, ``clamps``,
        the jit-wrapped ``evolve``/``advance``/``reset`` commands and the scanned
        ``_process`` routine used by ``fit``.
        """
        W, err = self.circuit.get_components("W", "err")
        # Re-bind the component handles fetched from the circuit (the original
        # code assigned W to a stray attribute named ``self`` by mistake).
        self.W = W
        self.err = err

        @Context.dynamicCommand
        def batch_set(batch_size):
            self.W.batch_size = batch_size
            self.err.batch_size = batch_size

        @Context.dynamicCommand
        def clamps(y_scaled, X):
            ## X drives both the synaptic input and the Hebbian pre-synaptic term
            self.W.inputs.set(X)
            self.W.pre.set(X)
            self.err.target.set(y_scaled)

        self.circuit.wrap_and_add_command(jit(self.circuit.evolve), name="evolve")
        self.circuit.wrap_and_add_command(jit(self.circuit.advance_state), name="advance")
        self.circuit.wrap_and_add_command(jit(self.circuit.reset), name="reset")

        @scanner
        def _process(compartment_values, args):
            ## one scan step: advance the circuit by a single (t, dt) pair and
            ## emit the synaptic weight values at that step
            _t, _dt = args
            compartment_values = self.circuit.advance_state(compartment_values, t=_t, dt=_dt)
            return compartment_values, compartment_values[self.W.weights.path]

    def thresholding(self, scale=2):
        """
        Zeroes every coefficient whose absolute value is below ``self.threshold``.

        The thresholded coefficients are written into ``W.weights`` unscaled, while
        ``self.coef_`` stores (and this method returns) them multiplied by ``scale``.
        If ``fit`` has not been run yet (no ``coef_`` attribute), the current values
        of ``W.weights`` are used as the starting coefficients.

        Args:
            scale: multiplier applied to the thresholded coefficients kept in
                ``self.coef_`` (Default: 2)

        Returns:
            a tuple (scaled thresholded coefficients, coefficients before thresholding)
        """
        coef_old = getattr(self, "coef_", None)
        if coef_old is None:
            # fit() is the only other place that assigns coef_; fall back to the
            # synapse's current weights instead of raising AttributeError
            coef_old = np.array(self.W.weights.value)

        keep_mask = jnp.abs(coef_old) >= self.threshold
        new_coeff = jnp.where(keep_mask, coef_old, 0.)

        # NOTE(review): coef_ is scaled but the weights written back to W are
        # not - kept exactly as in the original; confirm this is intended.
        self.coef_ = new_coeff * scale
        self.W.weights.set(new_coeff)

        return self.coef_, coef_old

    def fit(self, y, X):
        """
        Trains the circuit on the pair (y, X).

        The circuit is reset, ``X`` is clamped to ``W.inputs``/``W.pre`` and ``y`` to
        ``err.target``; then, for each epoch, ``_process`` is scanned over ``self.T``
        steps and ``evolve`` is called once.

        Args:
            y: target signals, clamped to ``err.target``

            X: input features, clamped to ``W.inputs`` and ``W.pre``

        Returns:
            a tuple (coefficients as a NumPy array, value of ``err.mu``, value of ``err.L``)
        """
        self.circuit.reset()
        self.circuit.clamps(y_scaled=y, X=X)

        # The (t, dt) schedule is identical for every epoch, so build it once.
        time_steps = jnp.array([[self.dt * step, self.dt] for step in range(self.T)])
        for _ in range(self.epochs):
            self.circuit._process(time_steps)
            self.circuit.evolve(t=self.T, dt=self.dt)

        self.coef_ = np.array(self.W.weights.value)

        return self.coef_, self.err.mu.value, self.err.L.value
| 152 | + |
| 153 | + |
| 154 | + |
| 155 | + |
| 156 | + |
| 157 | + |
0 commit comments