diff --git a/docs/images/museum/hgpc/HGPC_inputL.jpg b/docs/images/museum/hgpc/HGPC_inputL.jpg
new file mode 100644
index 000000000..a907f8b09
Binary files /dev/null and b/docs/images/museum/hgpc/HGPC_inputL.jpg differ
diff --git a/docs/images/museum/hgpc/Input_layer.png b/docs/images/museum/hgpc/Input_layer.png
new file mode 100644
index 000000000..52bb25f9f
Binary files /dev/null and b/docs/images/museum/hgpc/Input_layer.png differ
diff --git a/docs/images/museum/hgpc/hgpc_model.png b/docs/images/museum/hgpc/hgpc_model.png
new file mode 100644
index 000000000..036c82679
Binary files /dev/null and b/docs/images/museum/hgpc/hgpc_model.png differ
diff --git a/docs/museum/pc-sindy.md b/docs/museum/pc-sindy.md
new file mode 100644
index 000000000..9f4285182
--- /dev/null
+++ b/docs/museum/pc-sindy.md
@@ -0,0 +1,34 @@
+# Sparse Identification of Non-linear Dynamical Systems with Predictive Coding (PC-SINDy)
+
+In this section, we describe, build, simulate, and visualize a SINDy model with predictive coding (PC-SINDy) using NGC-Learn library components.
+
+The model **code** for this exhibit can be found [here](https://github.com/NACLab/pc_sindy.py).
+
+## Predictive Coding (PC)
+PC is a biologically plausible learning algorithm that learns effective representations of, and transformations over, the data.
+
+## Sparse Identification of Non-linear Dynamical Systems (SINDy)
+SINDy is a data-driven algorithm that recovers the time derivative of a dynamical system as a symbolic equation of the system's state vector.
+SINDy expresses the derivative (a linear operation acting over the time step Δt) as a linear transformation,
+by a coefficient matrix, of a manually constructed dictionary of candidate functions of the state vector.
+Dictionary regression combined with a LASSO (L1-norm) penalty promotes sparsity in the coefficient matrix,
+so that only the terms in the dictionary that actually govern the dynamics remain non-zero.
+
+## SINDy with Predictive Coding
+
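Before turning to the predictive-coding formulation, the following is a minimal, self-contained sketch of the sparse dictionary regression described above. It is not the exhibit's PC-SINDy code: the polynomial library, the sequentially-thresholded least-squares helper `stlsq` (used here as a simple stand-in for LASSO), and the threshold value are illustrative assumptions; the `linear_2D` vector field mirrors the one this PR adds in `ngclearn/utils/diffeq/ode_functions.py`.

```python
import jax.numpy as jnp
from jax import random

def polynomial_library(X):
    # Candidate dictionary Theta(X) for a 2D state (x1, x2): [1, x1, x2, x1^2, x1*x2, x2^2]
    x1, x2 = X[:, 0:1], X[:, 1:2]
    return jnp.concatenate([jnp.ones_like(x1), x1, x2, x1**2, x1 * x2, x2**2], axis=1)

def stlsq(Theta, dX, threshold=0.05, n_iters=10):
    # Sequentially-thresholded least squares: prune small coefficients, refit the survivors.
    W = jnp.linalg.lstsq(Theta, dX)[0]
    for _ in range(n_iters):
        mask = (jnp.abs(W) >= threshold).astype(Theta.dtype)
        cols = []
        for j in range(dX.shape[1]):
            Theta_j = Theta * mask[:, j]  # zero out pruned dictionary columns for state dim j
            cols.append(jnp.linalg.lstsq(Theta_j, dX[:, j:j + 1])[0] * mask[:, j:j + 1])
        W = jnp.concatenate(cols, axis=1)
    return W

# Toy check: exact derivatives of the linear_2D system, dx/dt = A x
A = jnp.array([[-0.1, 2.0], [-2.0, -0.1]])
X = random.normal(random.PRNGKey(0), (500, 2))   # sampled states
dX = X @ A.T                                     # their (noiseless) time derivatives
W = stlsq(polynomial_library(X), dX)
print(W)  # rows index [1, x1, x2, x1^2, x1*x2, x2^2]
```

With noiseless derivatives, the recovered rows of `W` for `x1` and `x2` approach the columns of `A`, while the constant and quadratic terms are driven to exactly zero, which is the sparsification the L1 penalty described above is meant to achieve.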

+ +

+ +## Predictive Coding Model Dynamics diff --git a/docs/museum/sparse_coding.md b/docs/museum/sparse_coding.md index f3e810373..d36dbf50a 100755 --- a/docs/museum/sparse_coding.md +++ b/docs/museum/sparse_coding.md @@ -68,7 +68,7 @@ $$ where we see that we aim to learn a two-layer generative system that specifically imposes a prior distribution `p(z)` over the latent feature detectors (via the -constraint function $\Omega\big(\mathbf{z}(t)\big)$) that we hope +constraint function $ \Omega\big(\mathbf{z}(t)\big) $ ) that we hope to extract in node `z`. Note that this two-layer model (or single latent-variable layer model) could either be the linear generative model from [1] or one similar to the model learned through ISTA [2] if a (soft) thresholding function is used instead. diff --git a/ngclearn/components/neurons/graded/rateCell.py b/ngclearn/components/neurons/graded/rateCell.py index afeea42ba..b00df24d2 100755 --- a/ngclearn/components/neurons/graded/rateCell.py +++ b/ngclearn/components/neurons/graded/rateCell.py @@ -177,7 +177,12 @@ def __init__(self, name, n_units, tau_m, prior=("gaussian", 0.), act_fx="identit self.shape = shape self.n_units = n_units self.batch_size = batch_size - self.fx, self.dfx = create_function(fun_name=act_fx) + + + omega_0 = None + if act_fx == "sine": + omega_0 = kwargs["omega_0"] + self.fx, self.dfx = create_function(fun_name=act_fx, args=omega_0) # compartments (state of the cell & parameters will be updated through stateless calls) restVals = jnp.zeros(_shape) diff --git a/ngclearn/components/synapses/__init__.py b/ngclearn/components/synapses/__init__.py index 69fc6dc8e..c833245c1 100644 --- a/ngclearn/components/synapses/__init__.py +++ b/ngclearn/components/synapses/__init__.py @@ -21,6 +21,6 @@ ## modulated synaptic components from .modulated.MSTDPETSynapse import MSTDPETSynapse ## patched synaptic components -from .patched.patchedSynapse import PatchedSynapse -from .patched.staticPatchedSynapse import StaticPatchedSynapse -from .patched.hebbianPatchedSynapse import HebbianPatchedSynapse +## from .patched.patchedSynapse import PatchedSynapse +## from .patched.staticPatchedSynapse import StaticPatchedSynapse +## from .patched.hebbianPatchedSynapse import HebbianPatchedSynapse diff --git a/ngclearn/components/synapses/patched/hebbianPatchedSynapse.py b/ngclearn/components/synapses/patched/hebbianPatchedSynapse.py deleted file mode 100644 index 935c1db9f..000000000 --- a/ngclearn/components/synapses/patched/hebbianPatchedSynapse.py +++ /dev/null @@ -1,316 +0,0 @@ -import matplotlib.pyplot as plt -from jax import random, numpy as jnp, jit -from functools import partial -from ngclearn.utils.optim import get_opt_init_fn, get_opt_step_fn -from ngclearn import resolver, Component, Compartment -from ngclearn.components.synapses import PatchedSynapse -from ngclearn.utils import tensorstats - -@partial(jit, static_argnums=[3, 4, 5, 6, 7, 8, 9]) -def _calc_update(pre, post, W, w_mask, w_bound, is_nonnegative=True, signVal=1., w_decay=0., - pre_wght=1., post_wght=1.): - """ - Compute a tensor of adjustments to be applied to a synaptic value matrix. 
- - Args: - pre: pre-synaptic statistic to drive Hebbian update - - post: post-synaptic statistic to drive Hebbian update - - W: synaptic weight values (at time t) - - w_bound: maximum value to enforce over newly computed efficacies - - is_nonnegative: (Unused) - - signVal: multiplicative factor to modulate final update by (good for - flipping the signs of a computed synaptic change matrix) - - w_decay: synaptic decay factor to apply to this update - - pre_wght: pre-synaptic weighting term (Default: 1.) - - post_wght: post-synaptic weighting term (Default: 1.) - - Returns: - an update/adjustment matrix, an update adjustment vector (for biases) - """ - _pre = pre * pre_wght - _post = post * post_wght - dW = jnp.matmul(_pre.T, _post) - db = jnp.sum(_post, axis=0, keepdims=True) - if w_bound > 0.: - dW = dW * (w_bound - jnp.abs(W)) - if w_decay > 0.: - dW = dW - W * w_decay - - if w_mask!=None: - dW = dW * w_mask - - return dW * signVal, db * signVal - -@partial(jit, static_argnums=[1,2, 3]) -def _enforce_constraints(W, w_mask, w_bound, is_nonnegative=True): - """ - Enforces constraints that the (synaptic) efficacies/values within matrix - `W` must adhere to. - - Args: - W: synaptic weight values (at time t) - - w_bound: maximum value to enforce over newly computed efficacies - - is_nonnegative: ensure updated value matrix is strictly non-negative - - Returns: - the newly evolved synaptic weight value matrix - """ - _W = W - if w_bound > 0.: - if is_nonnegative == True: - _W = jnp.clip(_W, 0., w_bound) - else: - _W = jnp.clip(_W, -w_bound, w_bound) - - if w_mask!=None: - _W = _W * w_mask - - return _W - -class HebbianPatchedSynapse(PatchedSynapse): - """ - A synaptic cable that adjusts its efficacies via a two-factor Hebbian - adjustment rule. - - | --- Synapse Compartments: --- - | inputs - input (takes in external signals) - | outputs - output signals (transformation induced by synapses) - | weights - current value matrix of synaptic efficacies - | biases - current value vector of synaptic bias values - | key - JAX PRNG key - | --- Synaptic Plasticity Compartments: --- - | pre - pre-synaptic signal to drive first term of Hebbian update (takes in external signals) - | post - post-synaptic signal to drive 2nd term of Hebbian update (takes in external signals) - | dWweights - current delta matrix containing changes to be applied to synaptic efficacies - | dBiases - current delta vector containing changes to be applied to bias values - | opt_params - locally-embedded optimizer statisticis (e.g., Adam 1st/2nd moments if adam is used) - - Args: - name: the string name of this cell - - shape: tuple specifying shape of this synaptic cable (usually a 2-tuple - with number of inputs by number of outputs) - - eta: global learning rate - - weight_init: a kernel to drive initialization of this synaptic cable's values; - typically a tuple with 1st element as a string calling the name of - initialization to use - - bias_init: a kernel to drive initialization of biases for this synaptic cable - (Default: None, which turns off/disables biases) - - w_bound: maximum weight to softly bound this cable's value matrix to; if - set to 0, then no synaptic value bounding will be applied - - is_nonnegative: enforce that synaptic efficacies are always non-negative - after each synaptic update (if False, no constraint will be applied) - - w_decay: degree to which (L2) synaptic weight decay is applied to the - computed Hebbian adjustment (Default: 0); note that decay is not - applied to any configured biases - - 
sign_value: multiplicative factor to apply to final synaptic update before - it is applied to synapses; this is useful if gradient descent style - optimization is required (as Hebbian rules typically yield - adjustments for ascent) - - optim_type: optimization scheme to physically alter synaptic values - once an update is computed (Default: "sgd"); supported schemes - include "sgd" and "adam" - - :Note: technically, if "sgd" or "adam" is used but `signVal = 1`, - then the ascent form of each rule is employed (signVal = -1) or - a negative learning rate will mean a descent form of the - `optim_scheme` is being employed - - pre_wght: pre-synaptic weighting factor (Default: 1.) - - post_wght: post-synaptic weighting factor (Default: 1.) - - resist_scale: a fixed scaling factor to apply to synaptic transform - (Default: 1.), i.e., yields: out = ((W * Rscale) * in) + b - - p_conn: probability of a connection existing (default: 1.); setting - this to < 1. will result in a sparser synaptic structure - """ - - def __init__(self, name, shape, n_sub_models, stride_shape=(0,0), eta=0., weight_init=None, bias_init=None, - w_mask=None, w_bound=1., is_nonnegative=False, w_decay=0., sign_value=1., - optim_type="sgd", pre_wght=1., post_wght=1., p_conn=1., - resist_scale=1., batch_size=1, **kwargs): - super().__init__(name, shape, n_sub_models, stride_shape, w_mask, weight_init, bias_init, resist_scale, - p_conn, batch_size=batch_size, **kwargs) - - self.n_sub_models = n_sub_models - self.sub_stride = stride_shape - - self.shape = (shape[0] + (2 * stride_shape[0]), - shape[1] + (2 * stride_shape[1])) - self.sub_shape = (shape[0]//n_sub_models + (2 * stride_shape[0]), - shape[1]//n_sub_models + (2* stride_shape[1])) - - ## synaptic plasticity properties and characteristics - self.Rscale = resist_scale - self.w_bound = w_bound - self.w_decay = w_decay ## synaptic decay - self.pre_wght = pre_wght - self.post_wght = post_wght - self.eta = eta - self.is_nonnegative = is_nonnegative - self.sign_value = sign_value - - ## optimization / adjustment properties (given learning dynamics above) - self.opt = get_opt_step_fn(optim_type, eta=self.eta) - - # compartments (state of the cell, parameters, will be updated through stateless calls) - self.preVals = jnp.zeros((self.batch_size, self.shape[0])) - self.postVals = jnp.zeros((self.batch_size, self.shape[1])) - self.pre = Compartment(self.preVals) - self.post = Compartment(self.postVals) - self.w_mask = w_mask - self.dWeights = Compartment(jnp.zeros(self.shape)) - self.dBiases = Compartment(jnp.zeros(self.shape[1])) - - #key, subkey = random.split(self.key.value) - self.opt_params = Compartment(get_opt_init_fn(optim_type)( - [self.weights.value, self.biases.value] - if bias_init else [self.weights.value])) - - @staticmethod - def _compute_update(w_mask, w_bound, is_nonnegative, sign_value, w_decay, pre_wght, - post_wght, pre, post, weights): - ## calculate synaptic update values - dW, db = _calc_update( - pre, post, weights, w_mask, w_bound, is_nonnegative=is_nonnegative, - signVal=sign_value, w_decay=w_decay, pre_wght=pre_wght, - post_wght=post_wght) - - return dW * jnp.where(0 != jnp.abs(weights), 1, 0) , db - - @staticmethod - def _evolve(w_mask, opt, w_bound, is_nonnegative, sign_value, w_decay, pre_wght, - post_wght, bias_init, pre, post, weights, biases, opt_params): - ## calculate synaptic update values - dWeights, dBiases = HebbianPatchedSynapse._compute_update( - w_mask, w_bound, is_nonnegative, sign_value, w_decay, - pre_wght, post_wght, pre, post, weights - 
) - ## conduct a step of optimization - get newly evolved synaptic weight value matrix - if bias_init != None: - opt_params, [weights, biases] = opt(opt_params, [weights, biases], [dWeights, dBiases]) - else: - # ignore db since no biases configured - opt_params, [weights] = opt(opt_params, [weights], [dWeights]) - ## ensure synaptic efficacies adhere to constraints - weights = _enforce_constraints(weights, w_mask, w_bound, is_nonnegative=is_nonnegative) - return opt_params, weights, biases, dWeights, dBiases - - @resolver(_evolve) - def evolve(self, opt_params, weights, biases, dWeights, dBiases): - self.opt_params.set(opt_params) - self.weights.set(weights) - self.biases.set(biases) - self.dWeights.set(dWeights) - self.dBiases.set(dBiases) - - @staticmethod - def _reset(batch_size, shape): - preVals = jnp.zeros((batch_size, shape[0])) - postVals = jnp.zeros((batch_size, shape[1])) - return ( - preVals, # inputs - postVals, # outputs - preVals, # pre - postVals, # post - jnp.zeros(shape), # dW - jnp.zeros(shape[1]), # db - ) - - @classmethod - def help(cls): ## component help function - properties = { - "synapse_type": "HebbianSynapse - performs an adaptable synaptic " - "transformation of inputs to produce output signals; " - "synapses are adjusted via two-term/factor Hebbian adjustment" - } - compartment_props = { - "inputs": - {"inputs": "Takes in external input signal values", - "pre": "Pre-synaptic statistic for Hebb rule (z_j)", - "post": "Post-synaptic statistic for Hebb rule (z_i)"}, - "states": - {"weights": "Synapse efficacy/strength parameter values", - "biases": "Base-rate/bias parameter values", - "key": "JAX PRNG key"}, - "analytics": - {"dWeights": "Synaptic weight value adjustment matrix produced at time t", - "dBiases": "Synaptic bias/base-rate value adjustment vector produced at time t"}, - "outputs": - {"outputs": "Output of synaptic transformation"}, - } - hyperparams = { - "shape": "Overall shape of synaptic weight value matrix; number inputs x number outputs", - "n_sub_models": "The number of submodels in each layer", - "stride_shape": "Stride shape of overlapping synaptic weight value matrix", - "batch_size": "Batch size dimension of this component", - "weight_init": "Initialization conditions for synaptic weight (W) values", - "bias_init": "Initialization conditions for bias/base-rate (b) values", - "resist_scale": "Resistance level scaling factor (applied to output of transformation)", - "p_conn": "Probability of a connection existing (otherwise, it is masked to zero)", - "is_nonnegative": "Should synapses be constrained to be non-negative post-updates?", - "sign_value": "Scalar `flipping` constant -- changes direction to Hebbian descent if < 0", - "eta": "Global (fixed) learning rate", - "pre_wght": "Pre-synaptic weighting coefficient (q_pre)", - "post_wght": "Post-synaptic weighting coefficient (q_post)", - "w_bound": "Soft synaptic bound applied to synapses post-update", - "w_decay": "Synaptic decay term", - "optim_type": "Choice of optimizer to adjust synaptic weights" - } - info = {cls.__name__: properties, - "compartments": compartment_props, - "dynamics": "outputs = [(W * Rscale) * inputs] + b ;" - "dW_{ij}/dt = eta * [(z_j * q_pre) * (z_i * q_post)] - W_{ij} * w_decay", - "hyperparameters": hyperparams} - return info - - @resolver(_reset) - def reset(self, inputs, outputs, pre, post, dWeights, dBiases): - self.inputs.set(inputs) - self.outputs.set(outputs) - self.pre.set(pre) - self.post.set(post) - self.dWeights.set(dWeights) - self.dBiases.set(dBiases) - 
- def __repr__(self): - comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))] - maxlen = max(len(c) for c in comps) + 5 - lines = f"[{self.__class__.__name__}] PATH: {self.name}\n" - for c in comps: - stats = tensorstats(getattr(self, c).value) - if stats is not None: - line = [f"{k}: {v}" for k, v in stats.items()] - line = ", ".join(line) - else: - line = "None" - lines += f" {f'({c})'.ljust(maxlen)}{line}\n" - return lines - -if __name__ == '__main__': - from ngcsimlib.context import Context - with Context("Bar") as bar: - Wab = HebbianPatchedSynapse("Wab", (9, 30), 3) - print(Wab) - plt.imshow(Wab.weights.value, cmap='gray') - plt.show() diff --git a/ngclearn/components/synapses/patched/patchedSynapse.py b/ngclearn/components/synapses/patched/patchedSynapse.py deleted file mode 100644 index ef22e06ee..000000000 --- a/ngclearn/components/synapses/patched/patchedSynapse.py +++ /dev/null @@ -1,192 +0,0 @@ -import matplotlib.pyplot as plt -from jax import random, numpy as jnp, jit -from ngclearn import resolver, Component, Compartment -from ngclearn.components.jaxComponent import JaxComponent -from ngclearn.utils import tensorstats -from ngclearn.utils.weight_distribution import initialize_params -from ngcsimlib.logger import info -import math - - -""" - 𝑳𝒾 𝑳𝑗 - β¬‡οΈŽ β¬‡οΈŽ - 𝒏𝒾 = 3, 𝑫𝒾 𝑫𝑗 - 𝒅𝒾 = 𝑫𝒾 / 𝒏𝒾 𝒅𝑗 = 𝑫𝑗 / 𝒏𝒾 - - ⎯[𝑀𝒾𝑗¹ 𝟘 𝟘 ]⎯ - Z𝒾 ⎯[ 𝟘 𝑀𝒾𝑗² 𝟘 ]⎯ Z𝑗 - ⎯[ 𝟘 𝟘 𝑀𝒾𝑗³]⎯ - -""" - - -def create_multi_patch_synapses(key, shape, n_sub_models, sub_stride, weight_init): - sub_shape = (shape[0] // n_sub_models, shape[1] // n_sub_models) - di, dj = sub_shape - si, sj = sub_stride - - weight_shape = ((n_sub_models * di) + 2 * si, (n_sub_models * dj) + 2 * sj) - weights = initialize_params(key[2], {"dist": "constant", "value": 0.}, weight_shape, use_numpy=True) - - for i in range(n_sub_models): - start_i = i * di - end_i = (i + 1) * di + 2 * si - start_j = i * dj - end_j = (i + 1) * dj + 2 * sj - - shape_ = (end_i - start_i, end_j - start_j) # (di + 2 * si, dj + 2 * sj) - - weights[start_i : end_i, - start_j : end_j] = initialize_params(key[2], - init_kernel=weight_init, - shape=shape_, - use_numpy=True) - if si!=0: - weights[:si,:] = 0. - weights[-si:,:] = 0. - if sj!=0: - weights[:,:sj] = 0. - weights[:, -sj:] = 0. - - return weights - - - -class PatchedSynapse(JaxComponent): ## base patched synaptic cable - # Define Functions - def __init__(self, name, shape, n_sub_models, stride_shape=(0,0), w_mask=None, weight_init=None, bias_init=None, - resist_scale=1., p_conn=1., batch_size=1, **kwargs): - super().__init__(name, **kwargs) - - self.Rscale = resist_scale - self.batch_size = batch_size - self.weight_init = weight_init - self.bias_init = bias_init - - self.n_sub_models = n_sub_models - self.sub_stride = stride_shape - - tmp_key, *subkeys = random.split(self.key.value, 4) - if self.weight_init is None: - info(self.name, "is using default weight initializer!") - self.weight_init = {"dist": "fan_in_gaussian"} - - weights = create_multi_patch_synapses(key=subkeys, shape=shape, n_sub_models=self.n_sub_models, sub_stride=self.sub_stride, - weight_init=self.weight_init) - - self.w_mask = jnp.where(weights!=0, 1, 0) - self.sub_shape = (shape[0]//n_sub_models, shape[1]//n_sub_models) - - self.shape = weights.shape - self.sub_shape = self.sub_shape[0]+(2*self.sub_stride[0]), self.sub_shape[1]+(2*self.sub_stride[1]) - - if 0. 
< p_conn < 1.: ## only non-zero and <1 probs allowed - mask = random.bernoulli(subkeys[1], p=p_conn, shape=self.shape) - weights = weights * mask ## sparsify matrix - - ## Compartment setup - preVals = jnp.zeros((self.batch_size, self.shape[0])) - postVals = jnp.zeros((self.batch_size, self.shape[1])) - self.inputs = Compartment(preVals) - self.outputs = Compartment(postVals) - self.weights = Compartment(weights) - - ## Set up (optional) bias values - if self.bias_init is None: - info(self.name, "is using default bias value of zero (no bias " - "kernel provided)!") - self.biases = Compartment(initialize_params(subkeys[2], bias_init, - (1, self.shape[1])) - if bias_init else 0.0) - @staticmethod - def _advance_state(Rscale, inputs, weights, biases): - outputs = (jnp.matmul(inputs, weights) * Rscale) + biases - return outputs - - @resolver(_advance_state) - def advance_state(self, outputs): - self.outputs.set(outputs) - - @staticmethod - def _reset(batch_size, shape): - preVals = jnp.zeros((batch_size, shape[0])) - postVals = jnp.zeros((batch_size, shape[1])) - inputs = preVals - outputs = postVals - return inputs, outputs - - @resolver(_reset) - def reset(self, inputs, outputs): - self.inputs.set(inputs) - self.outputs.set(outputs) - - def save(self, directory, **kwargs): - file_name = directory + "/" + self.name + ".npz" - if self.bias_init != None: - jnp.savez(file_name, weights=self.weights.value, - biases=self.biases.value) - else: - jnp.savez(file_name, weights=self.weights.value) - - def load(self, directory, **kwargs): - file_name = directory + "/" + self.name + ".npz" - data = jnp.load(file_name) - self.weights.set(data['weights']) - if "biases" in data.keys(): - self.biases.set(data['biases']) - - @classmethod - def help(cls): ## component help function - properties = { - "synapse_type": "PatchedSynapse - performs a synaptic transformation " - "of inputs to produce output signals (e.g., a " - "scaled linear multivariate transformation)" - } - compartment_props = { - "inputs": - {"inputs": "Takes in external input signal values"}, - "states": - {"weights": "Synapse efficacy/strength parameter values", - "biases": "Base-rate/bias parameter values", - "key": "JAX PRNG key"}, - "outputs": - {"outputs": "Output of synaptic transformation"}, - } - hyperparams = { - "shape": "Overall shape of synaptic weight value matrix; number inputs x number outputs", - "n_sub_models": "The number of submodels in each layer", - "stride_shape": "Stride shape of overlapping synaptic weight value matrix", - "batch_size": "Batch size dimension of this component", - "weight_init": "Initialization conditions for synaptic weight (W) values", - "bias_init": "Initialization conditions for bias/base-rate (b) values", - "resist_scale": "Resistance level scaling factor (Rscale); applied to output of transformation", - "p_conn": "Probability of a connection existing (otherwise, it is masked to zero)" - } - info = {cls.__name__: properties, - "compartments": compartment_props, - "dynamics": "outputs = [W * inputs] * Rscale + b", - "hyperparameters": hyperparams} - return info - - def __repr__(self): - comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))] - maxlen = max(len(c) for c in comps) + 5 - lines = f"[{self.__class__.__name__}] PATH: {self.name}\n" - for c in comps: - stats = tensorstats(getattr(self, c).value) - if stats is not None: - line = [f"{k}: {v}" for k, v in stats.items()] - line = ", ".join(line) - else: - line = "None" - lines += f" 
{f'({c})'.ljust(maxlen)}{line}\n" - return lines - -if __name__ == '__main__': - from ngcsimlib.context import Context - with Context("Bar") as bar: - Wab = PatchedSynapse("Wab", (9, 30), 3) - print(Wab) - plt.imshow(Wab.weights.value, cmap='gray') - plt.show() diff --git a/ngclearn/components/synapses/patched/staticPatchedSynapse.py b/ngclearn/components/synapses/patched/staticPatchedSynapse.py deleted file mode 100644 index 6cbcf988e..000000000 --- a/ngclearn/components/synapses/patched/staticPatchedSynapse.py +++ /dev/null @@ -1,4 +0,0 @@ -from .patchedSynapse import PatchedSynapse - -class StaticPatchedSynapse(PatchedSynapse): - pass \ No newline at end of file diff --git a/ngclearn/utils/diffeq/ode_functions.py b/ngclearn/utils/diffeq/ode_functions.py new file mode 100644 index 000000000..037ee9d1d --- /dev/null +++ b/ngclearn/utils/diffeq/ode_functions.py @@ -0,0 +1,36 @@ +import jax.numpy as jnp +import jax +from jax import jit +from functools import partial +import matplotlib.pyplot as plt + +''' + +x0 = jnp.array([3, -1.5]) +''' +@partial(jit, static_argnums=(0,)) +def linear_2D(t, x, params): + ''' + :param x: 2D vector + type: jax array + shape:(2,) + + :param t: Unused + + :param params: Unused + + :return: 2D vector: [ + -0.1 * x[0] + 2.0 * x[1], + -2.0 * x[0] - 0.1 * x[1] + ] + type: jax array + shape:(2,) + + ------------------------------------------ + * suggested init value- + x0 = jnp.array([3, -1.5]) + ''' + coeff = jnp.array([[-0.1, 2], [-2, -0.1]]).T + dfx_ = jnp.matmul(x, coeff) + + return dfx_ \ No newline at end of file diff --git a/ngclearn/utils/diffeq/ode_solver.py b/ngclearn/utils/diffeq/ode_solver.py new file mode 100644 index 000000000..565043eb3 --- /dev/null +++ b/ngclearn/utils/diffeq/ode_solver.py @@ -0,0 +1,292 @@ +import jax.numpy as jnp +from jax import jit +from functools import partial +from jax.lax import scan as _scan + + + +@jit +def _sum_combine(*args, **kwargs): ## fast co-routine for simple addition + sum = 0 + + for arg, val in zip(args, kwargs.values()): + sum = sum + val * arg + return sum + +@partial(jit, static_argnums=(3, 4)) +def _step_forward(t, x, dx_dt, dt, x_scale): ## internal step co-routine + _t = t + dt + _x = x * x_scale + dx_dt * dt + return _t, _x + +@partial(jit, static_argnums=(1, 2, 3, 4, )) +def euler(carry, dfx, dt, params, x_scale=1.): + """ + Iteratively integrates one step forward via the Euler method, i.e., a + first-order Runge-Kutta (RK-1) step. + + Args: + t: current time variable to advance by dt + + x: current variable values to advance/iteratively integrate (at time `t`) + + dfx: (ordinary) differential equation co-routine (as implemented in an + ngc-learn component) + + dt: integration time step (also referred to as `h` in mathematics) + + params: tuple containing configuration values/hyper-parameters for the + (ordinary) differential equation an ngc-learn component will provide + + x_scale: dampening factor to scale `x` by (Default: 1) + + Returns: + variable values iteratively integrated/advanced to next step (`t + dt`) + """ + t, x = carry + + dx_dt = dfx(t, x, params) + _t, _x = _step_forward(t, x, dx_dt, dt, x_scale) + + new_carry = (_t, _x) + return new_carry, (new_carry, carry) + +@partial(jit, static_argnums=(1, 2, 3, 4, )) +def heun(carry, dfx, dt, params, x_scale=1.): + """ + Iteratively integrates one step forward via Heun's method, i.e., a + second-order Runge-Kutta (RK-2) error-corrected step. 
This method utilizes + two (differential) function evaluations to estimate the solution at a given + point in time. + (Note: ngc-learn internally recognizes "rk2_heun" or "heun" for this routine) + + | Reference: + | Ascher, Uri M., and Linda R. Petzold. Computer methods for ordinary + | differential equations and differential-algebraic equations. Society for + | Industrial and Applied Mathematics, 1998. + + Args: + t: current time variable to advance by dt + + x: current variable values to advance/iteratively integrate (at time `t`) + + dfx: (ordinary) differential equation co-routine (as implemented in an + ngc-learn component) + + dt: integration time step (also referred to as `h` in mathematics) + + params: tuple containing configuration values/hyper-parameters for the + (ordinary) differential equation an ngc-learn component will provide + + x_scale: dampening factor to scale `x` by (Default: 1) + + Returns: + variable values iteratively integrated/advanced to next step (`t + dt`) + """ + t, x = carry + + dx_dt = dfx(t, x, params) + _t, _x = _step_forward(t, x, dx_dt, dt, x_scale) + _dx_dt = dfx(_t, _x, params) + summed_dx_dt = _sum_combine(dx_dt, _dx_dt, weight1=1, weight2=1) + _, _x = _step_forward(t, x, summed_dx_dt, dt * 0.5, x_scale) + + new_carry = (_t, _x) + return new_carry, (new_carry, carry) + +@partial(jit, static_argnums=(1, 2, 3, 4, )) +def rk2(carry, dfx, dt, params, x_scale=1.): + """ + Iteratively integrates one step forward via the midpoint method, i.e., a + second-order Runge-Kutta (RK-2) step. + (Note: ngc-learn internally recognizes "rk2" or "midpoint" for this routine) + + | Reference: + | Ascher, Uri M., and Linda R. Petzold. Computer methods for ordinary + | differential equations and differential-algebraic equations. Society for + | Industrial and Applied Mathematics, 1998. + + Args: + t: current time variable to advance by dt + + x: current variable values to advance/iteratively integrate (at time `t`) + + dfx: (ordinary) differential equation co-routine (as implemented in an + ngc-learn component) + + dt: integration time step (also referred to as `h` in mathematics) + + params: tuple containing configuration values/hyper-parameters for the + (ordinary) differential equation an ngc-learn component will provide + + x_scale: dampening factor to scale `x` by (Default: 1) + + Returns: + variable values iteratively integrated/advanced to next step (`t + dt`) + """ + t, x = carry + + f_1 = dfx(t, x, params) + t1, x1 = _step_forward(t, x, f_1, dt * 0.5, x_scale) + f_2 = dfx(t1, x1, params) + _t, _x = _step_forward(t, x, f_2, dt, x_scale) + + new_carry = (_t, _x) + return new_carry, (new_carry, carry) + +@partial(jit, static_argnums=(1, 2, 3, 4, )) +def rk4(carry, dfx, dt, params, x_scale=1.): + """ + Iteratively integrates one step forward via the midpoint method, i.e., a + fourth-order Runge-Kutta (RK-4) step. + (Note: ngc-learn internally recognizes "rk4" or this routine) + + | Reference: + | Ascher, Uri M., and Linda R. Petzold. Computer methods for ordinary + | differential equations and differential-algebraic equations. Society for + | Industrial and Applied Mathematics, 1998. 
+ + Args: + t: current time variable to advance by dt + + x: current variable values to advance/iteratively integrate (at time `t`) + + dfx: (ordinary) differential equation co-routine (as implemented in an + ngc-learn component) + + dt: integration time step (also referred to as `h` in mathematics) + + params: tuple containing configuration values/hyper-parameters for the + (ordinary) differential equation an ngc-learn component will provide + + x_scale: dampening factor to scale `x` by (Default: 1) + + Returns: + variable values iteratively integrated/advanced to next step (`t + dt`) + """ + + t, x = carry + + f_1 = dfx(t, x, params) + t1, x1 = _step_forward(t, x, f_1, dt * 0.5, x_scale) + + f_2 = dfx(t1, x1, params) + t2, x2 = _step_forward(t, x, f_2, dt * 0.5, x_scale) + + f_3 = dfx(t2, x2, params) + t3, x3 = _step_forward(t, x, f_3, dt, x_scale) + + f_4 = dfx(t3, x3, params) + + _dx_dt = _sum_combine(f_1, f_2, f_3, f_4, w_f1=1, w_f2=2, w_f3=2, w_f4=1) + _t, _x = _step_forward(t, x, _dx_dt, dt / 6, x_scale) + + new_carry = (_t, _x) + return new_carry, (new_carry, carry) + +@partial(jit, static_argnums=(1, 2, 3, 4,)) +def ralston(carry, dfx, dt, params, x_scale=1.): + """ + Iteratively integrates one step forward via Ralston's method, i.e., a + second-order Runge-Kutta (RK-2) error-corrected step. This method utilizes + two (differential) function evaluations to estimate the solution at a given + point in time. + (Note: ngc-learn internally recognizes "rk2_ralston" or "ralston" for this + routine) + + | Reference: + | Ralston, Anthony. "Runge-Kutta methods with minimum error bounds." + | Mathematics of computation 16.80 (1962): 431-437. + + Args: + t: current time variable to advance by dt + + x: current variable values to advance/iteratively integrate (at time `t`) + + dfx: (ordinary) differential equation co-routine (as implemented in an + ngc-learn component) + + dt: integration time step (also referred to as `h` in mathematics) + + params: tuple containing configuration values/hyper-parameters for the + (ordinary) differential equation an ngc-learn component will provide + + x_scale: dampening factor to scale `x` by (Default: 1) + + Returns: + variable values iteratively integrated/advanced to next step (`t + dt`) + """ + + t, x = carry + + dx_dt = dfx(t, x, params) ## k1 + tm, xm = _step_forward(t, x, dx_dt, dt * 0.75, x_scale) + _dx_dt = dfx(tm, xm, params) ## k2 + ## Note: new step is a weighted combination of k1 and k2 + summed_dx_dt = _sum_combine(dx_dt, _dx_dt, weight1=(1./3.), weight2=(2./3.)) + _t, _x = _step_forward(t, x, summed_dx_dt, dt, x_scale) + + new_carry = (_t, _x) + return new_carry, (new_carry, carry) + + +@partial(jit, static_argnums=(0, 3, 4, 5, 6, 7, 8)) +def solve_ode(method_name, t0, x0, T, dfx, dt, params=None, x_scale=1., sols_only=True): + + if method_name =='euler': + method = euler + elif method_name == 'heun': + method = heun + elif method_name == 'rk2': + method = rk2 + elif method_name =='rk4': + method = rk4 + elif method_name =='ralston': + method = ralston + + def scanner(carry, _): + return method(carry, dfx, dt, params, x_scale) + + x_T, (xs_next, xs_carry) = _scan(scanner, init=(t0, x0), xs=jnp.arange(T)) + + if not sols_only: + return x_T, xs_next, xs_carry + + return xs_next + + + +######################################################################################## +######################################################################################## +if __name__ == '__main__': + import matplotlib.pyplot as plt + from ode_functions 
import linear_2D + + dfx = linear_2D + x0 = jnp.array([3, -1.5]) + + dt = 1e-2 + t0 = 0. + T = 800 + + (t_final, x_final), (ts_sol, sol_euler), (ts_carr, xs_carr) = solve_ode('euler', t0, x0, T=T, dfx=dfx, dt=dt, params=None, sols_only=False) + (_, x_final), (_, sol_heun), (_, xs_carr) = solve_ode('heun', t0, x0, T=T, dfx=dfx, dt=dt, params=None, sols_only=False) + (_, x_final), (_, sol_rk2), (_, xs_carr) = solve_ode('rk2', t0, x0, T=T, dfx=dfx, dt=dt, params=None, sols_only=False) + (_, x_final), (_, sol_rk4), (_, xs_carr) = solve_ode('rk4', t0, x0, T=T, dfx=dfx, dt=dt, params=None, sols_only=False) + (_, x_final), (_, sol_ralston), (_, xs_carr) = solve_ode('ralston', t0, x0, T=T, dfx=dfx, dt=dt, params=None, sols_only=False) + + + plt.plot(ts_sol, sol_euler[:, 0], label='x0-Euler') + plt.plot(ts_sol, sol_heun[:, 0], label='x0-Heun') + plt.plot(ts_sol, sol_rk2[:, 0], label='x0-RK2') + plt.plot(ts_sol, sol_rk4[:, 0], label='x0-RK4') + plt.plot(ts_sol, sol_ralston[:, 0], label='x0-Ralston') + + plt.plot(ts_sol, sol_euler[:, 1], label='x1-Euler') + plt.plot(ts_sol, sol_heun[:, 1], label='x1-Heun') + plt.plot(ts_sol, sol_rk2[:, 1], label='x1-RK2') + plt.plot(ts_sol, sol_rk4[:, 1], label='x1-RK4') + plt.plot(ts_sol, sol_ralston[:, 1], label='x1-Ralston') + plt.legend(loc='best') + plt.grid() + plt.show() diff --git a/ngclearn/utils/diffeq/ode_utils.py b/ngclearn/utils/diffeq/ode_utils.py index 460fd2f48..0a55a55da 100755 --- a/ngclearn/utils/diffeq/ode_utils.py +++ b/ngclearn/utils/diffeq/ode_utils.py @@ -44,8 +44,9 @@ def get_integrator_code(integrationType): @jit def _sum_combine(*args, **kwargs): ## fast co-routine for simple addition sum = 0 + for arg, val in zip(args, kwargs.values()): - sum += arg * val + sum = sum + val * arg return sum @jit @@ -54,6 +55,7 @@ def _step_forward(t, x, dx_dt, dt, x_scale): ## internal step co-routine _x = x * x_scale + dx_dt * dt return _t, _x +@partial(jit, static_argnums=(2, )) def step_euler(t, x, dfx, dt, params, x_scale=1.): """ Iteratively integrates one step forward via the Euler method, i.e., a @@ -81,6 +83,7 @@ def step_euler(t, x, dfx, dt, params, x_scale=1.): _t, _x = _step_forward(t, x, dx_dt, dt, x_scale) return _t, _x +@partial(jit, static_argnums=(2, )) def step_heun(t, x, dfx, dt, params, x_scale=1.): """ Iteratively integrates one step forward via Heun's method, i.e., a @@ -113,13 +116,15 @@ def step_heun(t, x, dfx, dt, params, x_scale=1.): variable values iteratively integrated/advanced to next step (`t + dt`) """ dx_dt = dfx(t, x, params) + _t, _x = _step_forward(t, x, dx_dt, dt, x_scale) _dx_dt = dfx(_t, _x, params) summed_dx_dt = _sum_combine(dx_dt, _dx_dt, weight1=1, weight2=1) + _, _x = _step_forward(t, x, summed_dx_dt, dt * 0.5, x_scale) return _t, _x - +@partial(jit, static_argnums=(2, )) def step_rk2(t, x, dfx, dt, params, x_scale=1.): """ Iteratively integrates one step forward via the midpoint method, i.e., a @@ -150,13 +155,13 @@ def step_rk2(t, x, dfx, dt, params, x_scale=1.): variable values iteratively integrated/advanced to next step (`t + dt`) """ f_1 = dfx(t, x, params) - t1, x1 = _step_forward(t, x, f_1, dt * 0.5, x_scale) + t1, x1 = _step_forward(t, x, f_1, dt * 0.5, x_scale) f_2 = dfx(t1, x1, params) _t, _x = _step_forward(t, x, f_2, dt, x_scale) return _t, _x - +@partial(jit, static_argnums=(2, )) def step_rk4(t, x, dfx, dt, params, x_scale=1.): """ Iteratively integrates one step forward via the midpoint method, i.e., a @@ -201,8 +206,7 @@ def step_rk4(t, x, dfx, dt, params, x_scale=1.): _t, _x = _step_forward(t, x, 
_dx_dt, dt / 6, x_scale)
     return _t, _x
 
-
-
+@partial(jit, static_argnums=(2, ))
 def step_ralston(t, x, dfx, dt, params, x_scale=1.):
     """
     Iteratively integrates one step forward via Ralston's method, i.e., a
diff --git a/ngclearn/utils/model_utils.py b/ngclearn/utils/model_utils.py
index 59bc8a32a..704ec7e0f 100755
--- a/ngclearn/utils/model_utils.py
+++ b/ngclearn/utils/model_utils.py
@@ -83,6 +83,10 @@ def create_function(fun_name, args=None):
     if fun_name == "tanh":
         fx = tanh
         dfx = d_tanh
+    elif fun_name == "sine":
+        fx = sine
+        dfx = d_sine
+        omega_0 = args
     elif fun_name == "sigmoid":
         fx = sigmoid
         dfx = d_sigmoid
@@ -266,6 +270,34 @@ def d_relu(x):
     """
     return (x >= 0.).astype(jnp.float32)
 
+@jit
+def sine(x, omega_0=30):
+    """
+    f(x) = sin(x * omega_0).
+
+    Args:
+        x: input (tensor) value
+
+        omega_0: angular frequency scaling factor (Default: 30)
+
+    Returns:
+        output (tensor) value
+    """
+    return jnp.sin(omega_0 * x)
+
+@jit
+def d_sine(x, omega_0=30):
+    """
+    Derivative of the sine activation: f'(x) = omega_0 * cos(x * omega_0).
+
+    Args:
+        x: input (tensor) value
+
+        omega_0: angular frequency scaling factor (Default: 30)
+
+    Returns:
+        output (tensor) value
+    """
+    return omega_0 * jnp.cos(omega_0 * x)
+
+
 @jit
 def tanh(x):
     """
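As a quick, hedged aside on the new `sine` activation added to `model_utils.py`: its hand-coded derivative should match what JAX autodiff reports, which makes for a simple sanity check. The snippet below is an independent sketch (it redefines the function locally rather than importing ngc-learn) with an arbitrary test point.

```python
import jax.numpy as jnp
from jax import grad

omega_0 = 30.
sine_fx = lambda v: jnp.sin(omega_0 * v)   # mirrors sine(x, omega_0)
x = 0.2                                    # arbitrary test point
manual = omega_0 * jnp.cos(omega_0 * x)    # d_sine(x, omega_0) as documented above
auto = grad(sine_fx)(x)
print(manual, auto)                        # the two values should agree
```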