From bfd7d0d289e31915180f3b6e97349cff3e769728 Mon Sep 17 00:00:00 2001
From: Faezeh Habibi <155960330+Faezehabibi@users.noreply.github.com>
Date: Tue, 10 Dec 2024 05:53:00 -0500
Subject: [PATCH 1/2] add DeepMoD_PC

---
 exhibits/DeepMoD_PC/deepmod.py       | 323 +++++++++++++++++++++++++++
 exhibits/DeepMoD_PC/train_deepmod.py | 133 +++++++++++
 2 files changed, 456 insertions(+)
 create mode 100644 exhibits/DeepMoD_PC/deepmod.py
 create mode 100644 exhibits/DeepMoD_PC/train_deepmod.py

diff --git a/exhibits/DeepMoD_PC/deepmod.py b/exhibits/DeepMoD_PC/deepmod.py
new file mode 100644
index 0000000..13385d6
--- /dev/null
+++ b/exhibits/DeepMoD_PC/deepmod.py
@@ -0,0 +1,323 @@
+from jax import random, jit
+import numpy as np
+
+from ngclearn.utils import weight_distribution as dist
+from ngclearn import Context, numpy as jnp
+from ngclearn.components import (RateCell,
+                                 HebbianSynapse,
+                                 GaussianErrorCell,
+                                 StaticSynapse)
+from ngclearn.utils.model_utils import scanner
+from ngclearn.modules.regression.lasso import Iterative_Lasso as Lasso
+from ngclearn.modules.regression.elastic_net import Iterative_ElasticNet as ElasticNet
+from ngclearn.modules.regression.ridge import Iterative_Ridge as Ridge
+
+
+class DeepMoD():
+    """
+    Structure for constructing a predictive-coding model of deep-learning
+    driven model discovery (DeepMoD):
+
+    Both, Gert-Jan, Gijs Vermarien, and Remy Kusters. "Sparsely constrained
+    neural networks for model discovery of PDEs." arXiv preprint
+    arXiv:2011.04336 (2020).
+
+    Note that this model decouples the network's constraint to the
+    differential-equation terms from the sparsity-selection process, allowing
+    for more flexible and robust model discovery: a sparsity mask is computed
+    first and the network is then constrained using only the active terms.
+
+    (The original DeepMoD paper is: Both, Gert-Jan, et al. "DeepMoD: Deep
+    learning for model discovery in noisy data." Journal of Computational
+    Physics 428 (2021): 109985.)
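+
+    In broad strokes, each call to `process` runs two stages: (i) the
+    predictive-coding network settles and adapts so as to regress the state
+    trajectory from (scaled) time, and (ii) the chosen iterative solver
+    (lasso, elastic-net, or ridge) regresses the numerical time-derivative
+    of the network's prediction onto the candidate-function library;
+    `thresholding` then zeroes out any coefficient whose magnitude falls
+    below `threshold`.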
+
+    | Node Name Structure:
+    | z3 -(W3)-> e2, z2 -(W2)-> e1, z1 -(W1)-> e0;
+    | e2 -(E2)-> z2 <- e1, e1 -(E1)-> z1 <- e0
+    | Note: W1, W2, W3 -> Hebbian-adapted synapses
+
+    Args:
+        key: JAX seeding key
+
+        ts: Time series data points
+
+        dict_dim: Dimensionality of the dictionary/library space
+
+        lib_creator: Library creator function for building candidate functions out of the predicted values (Xmu)
+
+        in_dim: Input dimensionality
+
+        h1_dim: Dimensionality of the first hidden layer
+
+        h2_dim: Dimensionality of the second hidden layer
+
+        out_dim: Output dimensionality
+
+        batch_size: Number of samples to process in each batch
+
+        w_fill: Initial weight fill value (Default: 0.05)
+
+        lr: Learning rate for the solver's optimization (Default: 0.01)
+
+        lmbda: Regularization parameter (Default: 0.0001)
+
+        l1_ratio: Elastic-net mixing parameter (Default: 0.0)
+
+        optim_type: Type of optimizer to use for the solver (Default: "adam")
+
+        threshold: Threshold for sparse coefficient selection (Default: 0.001)
+
+        scale: Scaling factor for dictionary terms (Default: 2.0)
+
+        solver_name: Type of regression solver ("lasso", "elastic_net", or "ridge") (Default: "lasso")
+
+        eta: Learning rate for Hebbian updates (Default: 1e-3)
+
+        tau_m: Membrane time constant (Default: 20.0)
+
+        T: Number of discrete time steps for simulation (Default: 50)
+
+        dt: Integration time step (Default: 1.0)
+
+        model_name: Name identifier for the model (Default: "deepmod")
+    """
+    def __init__(self, key, ts, dict_dim, lib_creator, in_dim, h1_dim, h2_dim, out_dim, batch_size,
+                 w_fill=0.05, lr=0.01, lmbda=0.0001, l1_ratio=0., optim_type="adam", threshold=0.001, scale=2.,
+                 solver_name="lasso", eta=1e-3, tau_m=20., T=50, dt=1.,
+                 model_name="deepmod", **kwargs):
+        dkey, *subkeys = random.split(key, 10)
+
+        self.model_name = model_name
+        self.solver_name = solver_name
+        self.nodes = None
+        self.threshold = threshold
+
+        ## meta-parameters for model dynamics
+        self.T = T
+        self.dt = dt
+        self.ts = ts
+        self.eta = eta
+        self.lib_creator = lib_creator
+
+        ## configure the iterative sparse-regression solver
+        self.scale = scale
+        epochs = 100
+        sys_dim = out_dim
+        if solver_name == "lasso" or solver_name == "l1":
+            print(" >> Building Lasso solver model...")
+            self.method_params = (key, self.solver_name, sys_dim, dict_dim, batch_size, w_fill, lr,
+                                  lmbda, optim_type, threshold, epochs)
+            self.solver = Lasso(*self.method_params)
+            self.W_init = self.solver.W.weights.value
+        elif solver_name == "elastic_net" or solver_name == "l1l2":
+            print(" >> Building Elastic-Net solver model...")
+            self.method_params = (key, self.solver_name, sys_dim, dict_dim, batch_size, w_fill, lr,
+                                  lmbda, l1_ratio, optim_type, threshold, epochs)
+            self.solver = ElasticNet(*self.method_params)
+        elif solver_name == "ridge" or solver_name == "l2":
+            print(" >> Building Ridge solver model...")
+            self.method_params = (key, self.solver_name, sys_dim, dict_dim, batch_size, w_fill, lr,
+                                  lmbda, optim_type, threshold, epochs)
+            self.solver = Ridge(*self.method_params)
+        else:
+            raise ValueError("Unknown solver_name: {}".format(solver_name))
+
+        opt_type = "adam"
+        act_fx = "sine"
+        self.omega_0 = 30  ## frequency factor for the sine activations
+
+        W3_dist = dist.uniform(
+            amin=-1 / h2_dim,
+            amax=1 / h2_dim
+        )
+        W2_dist = dist.uniform(
+            amin=-np.sqrt(6 / h1_dim) / self.omega_0,
+            amax=np.sqrt(6 / h1_dim) / self.omega_0
+        )
+        W1_dist = dist.uniform(
+            amin=-np.sqrt(6 / out_dim) / self.omega_0,
+            amax=np.sqrt(6 / out_dim) / self.omega_0
+        )
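+        ## NOTE: the uniform ranges above follow a SIREN-style initialization
+        ## (Sitzmann et al., 2020) -- weights drawn from
+        ## U(-sqrt(6/fan)/omega_0, sqrt(6/fan)/omega_0) -- which is the
+        ## scheme that pairs with the sine activations used by the hidden
+        ## rate cells.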
+
+        with Context(self.model_name) as self.model:
+            ############ layer 3
+            self.z3 = RateCell("z3", n_units=in_dim, tau_m=tau_m, act_fx="identity")
+            self.W3 = HebbianSynapse("W3", shape=(in_dim, h2_dim), eta=eta, w_bound=0., sign_value=-1,
+                                     optim_type=opt_type, weight_init=W3_dist, key=subkeys[0])
+            ############ layer 2
+            self.e2 = GaussianErrorCell("e2", n_units=h2_dim)
+            self.z2 = RateCell("z2", n_units=h2_dim, tau_m=tau_m, act_fx=act_fx, omega_0=self.omega_0,
+                               batch_size=batch_size)
+            self.W2 = HebbianSynapse("W2", shape=(h2_dim, h1_dim), eta=eta, w_bound=0., sign_value=-1,
+                                     optim_type=opt_type, weight_init=W2_dist, key=subkeys[1])
+            self.E2 = StaticSynapse("E2", shape=(h1_dim, h2_dim))
+            ############ layer 1
+            self.e1 = GaussianErrorCell("e1", n_units=h1_dim)
+            self.z1 = RateCell("z1", n_units=h1_dim, tau_m=tau_m, act_fx="identity")
+            self.W1 = HebbianSynapse("W1", shape=(h1_dim, out_dim), eta=eta, w_bound=0., sign_value=-1,
+                                     optim_type=opt_type, weight_init=W1_dist, key=subkeys[2])
+            self.E1 = StaticSynapse("E1", shape=(out_dim, h1_dim))
+            ############ output (data) layer
+            self.e0 = GaussianErrorCell("e0", n_units=out_dim)
+            # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+            self.z3.batch_size = batch_size
+            self.z2.batch_size = batch_size
+            self.z1.batch_size = batch_size
+
+            self.e2.batch_size = batch_size
+            self.e1.batch_size = batch_size
+            self.e0.batch_size = batch_size
+
+            self.W3.batch_size = batch_size
+            self.W2.batch_size = batch_size
+            self.W1.batch_size = batch_size
+
+            self.E2.batch_size = batch_size
+            self.E1.batch_size = batch_size
+            # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+            ## wire the top-down prediction pathway
+            self.W3.inputs << self.z3.zF
+            self.e2.mu << self.W3.outputs
+
+            self.e2.target << self.z2.z
+            self.W2.inputs << self.z2.zF
+            self.e1.mu << self.W2.outputs
+
+            self.e1.target << self.z1.z
+            self.W1.inputs << self.z1.zF
+            self.e0.mu << self.W1.outputs
+
+            ## wire the error-feedback pathway
+            self.z2.j_td << self.e2.dtarget
+            self.E2.inputs << self.e1.dmu
+            self.z2.j << self.E2.outputs
+
+            self.z1.j_td << self.e1.dtarget
+            self.E1.inputs << self.e0.dmu
+            self.z1.j << self.E1.outputs
+
+            ## wire the pre/post compartments that drive the Hebbian updates
+            self.W1.pre << self.z1.zF
+            self.W1.post << self.e0.dmu
+
+            self.W2.pre << self.z2.zF
+            self.W2.post << self.e1.dmu
+
+            self.W3.pre << self.z3.zF
+            self.W3.post << self.e2.dmu
+
+            # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+            advance_cmd, advance_args = self.model.compile_by_key(
+                self.E2, self.E1,  ## execute feedback first
+                self.z3, self.z2, self.z1,
+                self.W3, self.W2, self.W1,  ## execute prediction synapses
+                self.e2, self.e1, self.e0,  ## finally, execute error neurons
+                compile_key="advance_state", name='advance_state')
+
+            evolve_cmd, evolve_args = self.model.compile_by_key(self.W1, self.W2, self.W3,
+                                                                compile_key="evolve")
+
+            reset_cmd, reset_args = self.model.compile_by_key(self.z3, self.z2, self.z1,
+                                                              self.e2, self.e1, self.e0,
+                                                              self.W3, self.W2, self.W1,
+                                                              self.E1, self.E2,
+                                                              compile_key="reset")
+
+            # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+            self.dynamic()
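+
+    ## Usage sketch (hypothetical shapes; see train_deepmod.py for a complete
+    ## driver). Assuming `ts` is an (N, 1) array of time points, `X` is the
+    ## (N, out_dim) observed trajectory, and `lib` builds the candidate-term
+    ## library:
+    ##     model = DeepMoD(key, ts, dict_dim, lib, in_dim=1, h1_dim=16,
+    ##                     h2_dim=16, out_dim=X.shape[1], batch_size=X.shape[0])
+    ##     coef, loss = model.process(ts_scaled, X)  ## one fit + regression pass
+    ##     sparse_coef = model.thresholding()        ## prune |w| < threshold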
+
+    def dynamic(self):  ## create dynamic commands for the circuit
+        z3, z2, z1, W3, W2, W1, E1, E2, e0, e1, e2 = self.model.get_components("z3", "z2", "z1",
+                                                                               "W3", "W2", "W1",
+                                                                               "E1", "E2",
+                                                                               "e0", "e1", "e2")
+        self.W1, self.W2, self.W3 = (W1, W2, W3)
+        self.e0, self.e1, self.e2 = (e0, e1, e2)
+        self.z1, self.z2, self.z3 = (z1, z2, z3)
+        self.E1, self.E2 = (E1, E2)
+
+        @Context.dynamicCommand
+        def clamps(input, target):
+            ## clamp the time input to the top layer and the data to the output error cell
+            self.z3.z.set(input)
+            self.e0.target.set(target)
+
+        @Context.dynamicCommand
+        def batch_set(batch_size):
+            ## resize every component to the current batch dimension
+            self.z3.batch_size = batch_size
+            self.z2.batch_size = batch_size
+            self.z1.batch_size = batch_size
+
+            self.e2.batch_size = batch_size
+            self.e1.batch_size = batch_size
+            self.e0.batch_size = batch_size
+
+            self.W3.batch_size = batch_size
+            self.W2.batch_size = batch_size
+            self.W1.batch_size = batch_size
+
+            self.E2.batch_size = batch_size
+            self.E1.batch_size = batch_size
+
+        self.model.wrap_and_add_command(jit(self.model.evolve), name="evolve")
+        self.model.wrap_and_add_command(jit(self.model.reset), name="reset")
+
+        @scanner
+        def _process(compartment_values, args):
+            ## advance the full network state one step; emit the output prediction
+            _t, _dt = args
+            compartment_values = self.model.advance_state(
+                compartment_values, t=_t, dt=_dt)
+            return compartment_values, compartment_values[self.W1.outputs.path]
+
+    def prediction_process(self, input, target):
+        self.model.batch_set(len(input))
+        ## tie the feedback synapses to the transposes of the forward weights
+        self.E1.weights.set(self.W1.weights.value.T)
+        self.E2.weights.set(self.W2.weights.value.T)
+
+        self.model.reset()
+        self.model.clamps(input, target)
+
+        ## settle the network for T steps, then apply one Hebbian update
+        z_codes = self.model._process(jnp.array([[self.dt * i, self.dt] for i in range(self.T)]))
+        self.model.evolve(t=self.T, dt=self.dt)
+
+        return self.e0.mu.value, self.e0.L.value
+
+    def thresholding(self):
+        ## zero out solver coefficients whose magnitude falls below threshold
+        coef_old = self.solver.W.weights.value
+        coef_new = jnp.where(jnp.abs(coef_old) >= self.threshold, coef_old, 0.)
+
+        self.solver.W.weights.set(coef_new)
+
+        return coef_new
+
+    def process(self, ts_scaled, X):
+        self.model.batch_set(len(ts_scaled))
+        ## stage 1: fit the predictive-coding network to the trajectory
+        Xmu, loss = self.prediction_process(input=self.ts, target=X)
+
+        ## stage 2: regress the numerical derivative of the network's
+        ## prediction onto the candidate-function library
+        library, _ = self.lib_creator.fit([Xmu[:, i] for i in range(Xmu.shape[1])])
+        dXmu = jnp.array(np.gradient(jnp.array(Xmu), self.ts.ravel(), axis=0))
+
+        coef = self.solver.fit(y=dXmu / self.scale, X=library)[0]
+
+        return coef, loss
\ No newline at end of file
diff --git a/exhibits/DeepMoD_PC/train_deepmod.py b/exhibits/DeepMoD_PC/train_deepmod.py
new file mode 100644
index 0000000..3a21f41
--- /dev/null
+++ b/exhibits/DeepMoD_PC/train_deepmod.py
@@ -0,0 +1,133 @@
+import jax
+from jax import random, jit
+import numpy as np
+from ngclearn import Context, numpy as jnp
+# print(jax.__version__)
+from deepmod import DeepMoD
+from ngclearn.utils.feature_dictionaries.polynomialLibrary import PolynomialLibrary
+from ngclearn.utils.diffeq.odes import cubic_2D, linear_2D, lorenz, oscillator, linear_3D
+from ngclearn.utils.diffeq.ode_solver import solve_ode
+# -------------------------------------
+np.set_printoptions(suppress=True, precision=3)
+
+key = random.PRNGKey(1234)
+key_ = random.PRNGKey(3476)
+# ------------------------------------------- System Configs ---------------------------------------------
+dfx = linear_2D
+include_bias = False
+eta = 0.01
+
+if dfx == linear_2D:
+    x0 = jnp.array([3, -1.5], dtype=jnp.float32)
+    deg = 2
+    threshold = 0.02
+    T = 800
+    prob = 0.3
+    w_fill = 0.05
+    lr = 0.01
+    include_bias = False
+elif dfx == linear_3D:
+    x0 = jnp.array([1, 1., -1], dtype=jnp.float32)
+    deg = 2
+    threshold = 0.05
+    T = 2000
+    prob = 0.3
+    lr = 0.01
+    w_fill = 0.05
+    include_bias = False
+elif dfx == cubic_2D:
+    x0 = jnp.array([2., 0.], dtype=jnp.float32)
+    deg = 3
+    threshold = 0.05
+    T = 1000
+    w_fill = 0.05
+    lr = 0.01
+    prob = 0.3
+    include_bias = False
+elif dfx == lorenz:
+    x0 = jnp.array([-8, 8, 27], dtype=jnp.float32)
+    threshold = 0.5
+    deg = 2
+    T = 1000
+    prob = 0.3
+    eta = 0.02
+    w_fill = 0.05
+    lr = 0.002
+    include_bias = False
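+
+# For reference: assuming ngclearn's `linear_2D` matches the standard SINDy
+# benchmark system, the ground-truth dynamics for the default config are
+#     dx/dt = -0.1 x + 2.0 y
+#     dy/dt = -2.0 x - 0.1 y
+# so a successful run should recover coefficients close to these four values,
+# with every other library term thresholded to zero.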
+
+# ---------------------------------- Solving System (for Data generation) ------------------------------------------
+n_epochs = 2000
+dt = 1e-2
+t0 = 0.
+
+ts, X = solve_ode('rk4', t0, x0, T=T, dfx=dfx, dt=dt, params=None, sols_only=True)
+# ------------------------------------- Numerical derivative calculation -------------------------------------------
+dX = jnp.array(np.gradient(jnp.array(X), jnp.array(ts), axis=0))
+
+# ------------------------------------------- Create Library of features -------------------------------------------
+library_creator = PolynomialLibrary(poly_order=deg, include_bias=include_bias)
+feature_lib, feature_names = library_creator.fit([X[:, i] for i in range(X.shape[1])])
+
+# -------------------------------------------- Preprocessing -------------------------------------------------------
+## rescale the time axis so that ts_scaled ~ [-1, 1], shape: (T, 1)
+new_min = -1
+new_max = 1
+new_rng = new_max - new_min
+t_min, t_max = ts.min(), ts.max()
+data_rng = t_max - t_min
+scale_ = data_rng / new_rng
+ts_shifted = ts - t_min
+ts_centered = new_min + (ts_shifted / scale_)
+ts_scaled = ts_centered.reshape(ts.shape[0], 1)
+
+## rescale the hyper-parameters to account for the time-axis rescaling
+scale = scale_ / 2
+w_fill = w_fill * (scale / 2)
+lr = lr * (scale * 0.5)
+threshold = (threshold / scale) * (1 + prob)
+# ##################################################################################################
+# # System
+# ##################################################################################################
+
+in_dim = ts_scaled.shape[1]
+h1_dim = 16
+h2_dim = 16
+out_dim = X.shape[1]
+batch_size = X.shape[0]
+feat_dim = feature_lib.shape[1]
+lasso_lmbda = 0.
+
+deepmod = DeepMoD(key=key, ts=ts[:, None], dict_dim=feat_dim, lib_creator=library_creator,
+                  solver_name="l1", l1_ratio=0.5, eta=eta,
+                  in_dim=in_dim, h1_dim=h1_dim, h2_dim=h2_dim, out_dim=out_dim,
+                  batch_size=batch_size, threshold=threshold, scale=scale,
+                  w_fill=w_fill, lr=lr, lmbda=lasso_lmbda)
+
+
+coeff_track = 0
+for i in range(n_epochs):
+    coeff, loss_pred = deepmod.process(ts_scaled, X)
+
+    sparse_W = (deepmod.thresholding() * scale).T
+    print("\r >epoch={} L= {:.4f} | Sparse Weight: Wdx = {} | Wdy = {}".format(i,
+                                                                               loss_pred / T,
+                                                                               sparse_W[0],
+                                                                               sparse_W[1]), end="")
+    if i % 100 == 0:
+        print()
+
+    ## convergence criterion: mean change of the (unthresholded) coefficients
+    conv_crit = (coeff_track - coeff).mean()
+    coeff_track = coeff
+    if jnp.abs(conv_crit) <= 5e-8 or i == n_epochs - 1:
+        print('model converged at epoch', i, 'with coefficients \n', deepmod.thresholding().T)
+        break
+print()
+
+
+print('done')

From 8e745b1b07aeb092e9c96e317a909c193e06c577 Mon Sep 17 00:00:00 2001
From: Viet Dung Nguyen
Date: Fri, 13 Dec 2024 09:54:25 +0700
Subject: [PATCH 2/2] normally, if we can import jax numpy directly, we should
 do it instead of importing the numpy module inside ngclearn

---
 exhibits/DeepMoD_PC/train_deepmod.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/exhibits/DeepMoD_PC/train_deepmod.py b/exhibits/DeepMoD_PC/train_deepmod.py
index 3a21f41..3ef5df3 100644
--- a/exhibits/DeepMoD_PC/train_deepmod.py
+++ b/exhibits/DeepMoD_PC/train_deepmod.py
@@ -1,7 +1,8 @@
 import jax
 from jax import random, jit
 import numpy as np
-from ngclearn import Context, numpy as jnp
+from ngclearn import Context
+import jax.numpy as jnp
 # print(jax.__version__)
 from deepmod import DeepMoD
 from ngclearn.utils.feature_dictionaries.polynomialLibrary import PolynomialLibrary