diff --git a/docs/images/museum/hgpc/HGPC_inputL.jpg b/docs/images/museum/hgpc/HGPC_inputL.jpg
new file mode 100644
index 000000000..a907f8b09
Binary files /dev/null and b/docs/images/museum/hgpc/HGPC_inputL.jpg differ
diff --git a/docs/images/museum/hgpc/Input_layer.png b/docs/images/museum/hgpc/Input_layer.png
new file mode 100644
index 000000000..52bb25f9f
Binary files /dev/null and b/docs/images/museum/hgpc/Input_layer.png differ
diff --git a/docs/images/museum/hgpc/hgpc_model.png b/docs/images/museum/hgpc/hgpc_model.png
new file mode 100644
index 000000000..036c82679
Binary files /dev/null and b/docs/images/museum/hgpc/hgpc_model.png differ
diff --git a/docs/museum/pc-sindy.md b/docs/museum/pc-sindy.md
new file mode 100644
index 000000000..9f4285182
--- /dev/null
+++ b/docs/museum/pc-sindy.md
@@ -0,0 +1,34 @@
+# Sparse Identification of Non-linear Dynamical Systems with Predictive Coding (PC-SINDy)
+
+In this section, we build, simulate, and visualize a SINDy model driven by predictive coding (PC-SINDy) using ngc-learn library components.
+
+
+
+
+
+
+
+The model **code** for this exhibit can be found [here](https://github.com/NACLab/pc_sindy.py).
+
+
+## Predictive Coding (PC)
+PC is a biologically plausible learning algorithm that learns effective representations of, and transformations over, the data by minimizing locally computed prediction errors.
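+
+As a brief sketch (a generic single-latent-layer formulation, not necessarily the exact circuit this exhibit builds), a PC model maintains a latent state $\mathbf{z}$, produces a prediction $g(\mathbf{W}\mathbf{z})$ of the observation $\mathbf{x}$, and evolves both the state and the synapses so as to reduce the resulting prediction error:
+
+$$
+\mathbf{e} = \mathbf{x} - g(\mathbf{W}\mathbf{z}), \qquad
+\tau \frac{\partial \mathbf{z}}{\partial t} = -\gamma \mathbf{z} + \mathbf{W}^{\top}\big(\mathbf{e} \odot g'(\mathbf{W}\mathbf{z})\big), \qquad
+\Delta \mathbf{W} \propto \mathbf{e} \cdot \mathbf{z}^{\top}
+$$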
+
+
+## Sparse Identification of Non-linear Dynamical Systems (SINDy)
+SINDy is a data-driven algorithm that recovers the governing equations of a dynamical system as a symbolic expression of the system's state vector.
+It models the state derivative $\dot{\mathbf{X}}$ as a linear combination of a manually constructed
+dictionary of candidate functions of the state vector, weighted by a coefficient matrix.
+Combining this dictionary regression with a LASSO (L1-norm) penalty promotes sparsity in the coefficient matrix,
+so that only the dictionary terms that actually govern the dynamics remain non-zero.
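+
+In matrix form (the standard SINDy presentation, independent of this exhibit's particular code), collecting
+state snapshots into $\mathbf{X}$ and evaluating a library $\mathbf{\Theta}(\mathbf{X})$ of candidate terms
+(e.g., low-order polynomials), identification becomes the sparse regression problem:
+
+$$
+\dot{\mathbf{X}} \approx \mathbf{\Theta}(\mathbf{X})\,\mathbf{\Xi}, \qquad
+\mathbf{\Xi} = \underset{\mathbf{\Xi}'}{\mathrm{argmin}} \; \big\|\dot{\mathbf{X}} - \mathbf{\Theta}(\mathbf{X})\,\mathbf{\Xi}'\big\|^2_2 + \lambda \big\|\mathbf{\Xi}'\big\|_1
+$$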
+
+Learning the coefficient matrix therefore amounts to solving a sparsity-regularized regression over the dictionary's columns; a reference sketch of one classical solver is given below.
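+
+For reference only (this is not the predictive-coding learning rule this exhibit uses), a minimal sketch of the
+classical SINDy pipeline, with a hypothetical polynomial dictionary and sequentially thresholded least squares
+(STLSQ) standing in for an explicit LASSO solver, might look like:
+
+```python
+import jax.numpy as jnp
+
+def polynomial_library(X):  ## hypothetical dictionary: [1, x, y, x^2, xy, y^2] for 2D states
+    x, y = X[:, 0:1], X[:, 1:2]
+    return jnp.concatenate([jnp.ones_like(x), x, y, x * x, x * y, y * y], axis=1)
+
+def stlsq(Theta, dX, threshold=0.1, iters=10):
+    """Sequentially thresholded least squares: a simple sparsity-promoting stand-in for LASSO."""
+    Xi, *_ = jnp.linalg.lstsq(Theta, dX)  ## dense initial fit: dX ~= Theta @ Xi
+    for _ in range(iters):
+        Xi = jnp.where(jnp.abs(Xi) < threshold, 0., Xi)  ## prune small coefficients
+        for j in range(dX.shape[1]):  ## refit surviving terms, one output dimension at a time
+            keep = jnp.abs(Xi[:, j]) >= threshold
+            coeffs, *_ = jnp.linalg.lstsq(Theta * keep[None, :], dX[:, j:j + 1])
+            Xi = Xi.at[:, j].set(jnp.where(keep, coeffs[:, 0], 0.))
+    return Xi  ## sparse coefficient matrix
+```
+
+Here, the dictionary terms, threshold, and helper names are illustrative assumptions; PC-SINDy instead learns
+(and sparsifies) the coefficients with the predictive-coding components in the exhibit code linked above.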
+
+
+## SINDy with Predictive Coding
+
+
+
+
+
+## Predictive Coding Model Dynamics
diff --git a/docs/museum/sparse_coding.md b/docs/museum/sparse_coding.md
index f3e810373..d36dbf50a 100755
--- a/docs/museum/sparse_coding.md
+++ b/docs/museum/sparse_coding.md
@@ -68,7 +68,7 @@ $$
where we see that we aim to learn a two-layer generative system that specifically
imposes a prior distribution `p(z)` over the latent feature detectors (via the
-constraint function $\Omega\big(\mathbf{z}(t)\big)$) that we hope
+constraint function $\Omega\big(\mathbf{z}(t)\big)$ ) that we hope
to extract in node `z`. Note that this two-layer model (or single latent-variable layer
model) could either be the linear generative model from [1] or one similar to the
model learned through ISTA [2] if a (soft) thresholding function is used instead.
diff --git a/ngclearn/components/neurons/graded/rateCell.py b/ngclearn/components/neurons/graded/rateCell.py
index afeea42ba..b00df24d2 100755
--- a/ngclearn/components/neurons/graded/rateCell.py
+++ b/ngclearn/components/neurons/graded/rateCell.py
@@ -177,7 +177,12 @@ def __init__(self, name, n_units, tau_m, prior=("gaussian", 0.), act_fx="identit
self.shape = shape
self.n_units = n_units
self.batch_size = batch_size
- self.fx, self.dfx = create_function(fun_name=act_fx)
+
+        omega_0 = None
+        if act_fx == "sine":
+            ## the sine activation requires a frequency term; default matches model_utils.sine
+            omega_0 = kwargs.get("omega_0", 30.)
+        self.fx, self.dfx = create_function(fun_name=act_fx, args=omega_0)
# compartments (state of the cell & parameters will be updated through stateless calls)
restVals = jnp.zeros(_shape)
diff --git a/ngclearn/components/synapses/__init__.py b/ngclearn/components/synapses/__init__.py
index 69fc6dc8e..c833245c1 100644
--- a/ngclearn/components/synapses/__init__.py
+++ b/ngclearn/components/synapses/__init__.py
@@ -21,6 +21,6 @@
## modulated synaptic components
from .modulated.MSTDPETSynapse import MSTDPETSynapse
## patched synaptic components
-from .patched.patchedSynapse import PatchedSynapse
-from .patched.staticPatchedSynapse import StaticPatchedSynapse
-from .patched.hebbianPatchedSynapse import HebbianPatchedSynapse
+## from .patched.patchedSynapse import PatchedSynapse
+## from .patched.staticPatchedSynapse import StaticPatchedSynapse
+## from .patched.hebbianPatchedSynapse import HebbianPatchedSynapse
diff --git a/ngclearn/components/synapses/patched/hebbianPatchedSynapse.py b/ngclearn/components/synapses/patched/hebbianPatchedSynapse.py
deleted file mode 100644
index 935c1db9f..000000000
--- a/ngclearn/components/synapses/patched/hebbianPatchedSynapse.py
+++ /dev/null
@@ -1,316 +0,0 @@
-import matplotlib.pyplot as plt
-from jax import random, numpy as jnp, jit
-from functools import partial
-from ngclearn.utils.optim import get_opt_init_fn, get_opt_step_fn
-from ngclearn import resolver, Component, Compartment
-from ngclearn.components.synapses import PatchedSynapse
-from ngclearn.utils import tensorstats
-
-@partial(jit, static_argnums=[3, 4, 5, 6, 7, 8, 9])
-def _calc_update(pre, post, W, w_mask, w_bound, is_nonnegative=True, signVal=1., w_decay=0.,
- pre_wght=1., post_wght=1.):
- """
- Compute a tensor of adjustments to be applied to a synaptic value matrix.
-
- Args:
- pre: pre-synaptic statistic to drive Hebbian update
-
- post: post-synaptic statistic to drive Hebbian update
-
- W: synaptic weight values (at time t)
-
- w_bound: maximum value to enforce over newly computed efficacies
-
- is_nonnegative: (Unused)
-
- signVal: multiplicative factor to modulate final update by (good for
- flipping the signs of a computed synaptic change matrix)
-
- w_decay: synaptic decay factor to apply to this update
-
- pre_wght: pre-synaptic weighting term (Default: 1.)
-
- post_wght: post-synaptic weighting term (Default: 1.)
-
- Returns:
- an update/adjustment matrix, an update adjustment vector (for biases)
- """
- _pre = pre * pre_wght
- _post = post * post_wght
- dW = jnp.matmul(_pre.T, _post)
- db = jnp.sum(_post, axis=0, keepdims=True)
- if w_bound > 0.:
- dW = dW * (w_bound - jnp.abs(W))
- if w_decay > 0.:
- dW = dW - W * w_decay
-
- if w_mask!=None:
- dW = dW * w_mask
-
- return dW * signVal, db * signVal
-
-@partial(jit, static_argnums=[1,2, 3])
-def _enforce_constraints(W, w_mask, w_bound, is_nonnegative=True):
- """
- Enforces constraints that the (synaptic) efficacies/values within matrix
- `W` must adhere to.
-
- Args:
- W: synaptic weight values (at time t)
-
- w_bound: maximum value to enforce over newly computed efficacies
-
- is_nonnegative: ensure updated value matrix is strictly non-negative
-
- Returns:
- the newly evolved synaptic weight value matrix
- """
- _W = W
- if w_bound > 0.:
- if is_nonnegative == True:
- _W = jnp.clip(_W, 0., w_bound)
- else:
- _W = jnp.clip(_W, -w_bound, w_bound)
-
- if w_mask!=None:
- _W = _W * w_mask
-
- return _W
-
-class HebbianPatchedSynapse(PatchedSynapse):
- """
- A synaptic cable that adjusts its efficacies via a two-factor Hebbian
- adjustment rule.
-
- | --- Synapse Compartments: ---
- | inputs - input (takes in external signals)
- | outputs - output signals (transformation induced by synapses)
- | weights - current value matrix of synaptic efficacies
- | biases - current value vector of synaptic bias values
- | key - JAX PRNG key
- | --- Synaptic Plasticity Compartments: ---
- | pre - pre-synaptic signal to drive first term of Hebbian update (takes in external signals)
- | post - post-synaptic signal to drive 2nd term of Hebbian update (takes in external signals)
- | dWweights - current delta matrix containing changes to be applied to synaptic efficacies
- | dBiases - current delta vector containing changes to be applied to bias values
- | opt_params - locally-embedded optimizer statisticis (e.g., Adam 1st/2nd moments if adam is used)
-
- Args:
- name: the string name of this cell
-
- shape: tuple specifying shape of this synaptic cable (usually a 2-tuple
- with number of inputs by number of outputs)
-
- eta: global learning rate
-
- weight_init: a kernel to drive initialization of this synaptic cable's values;
- typically a tuple with 1st element as a string calling the name of
- initialization to use
-
- bias_init: a kernel to drive initialization of biases for this synaptic cable
- (Default: None, which turns off/disables biases)
-
- w_bound: maximum weight to softly bound this cable's value matrix to; if
- set to 0, then no synaptic value bounding will be applied
-
- is_nonnegative: enforce that synaptic efficacies are always non-negative
- after each synaptic update (if False, no constraint will be applied)
-
- w_decay: degree to which (L2) synaptic weight decay is applied to the
- computed Hebbian adjustment (Default: 0); note that decay is not
- applied to any configured biases
-
- sign_value: multiplicative factor to apply to final synaptic update before
- it is applied to synapses; this is useful if gradient descent style
- optimization is required (as Hebbian rules typically yield
- adjustments for ascent)
-
- optim_type: optimization scheme to physically alter synaptic values
- once an update is computed (Default: "sgd"); supported schemes
- include "sgd" and "adam"
-
- :Note: technically, if "sgd" or "adam" is used but `signVal = 1`,
- then the ascent form of each rule is employed (signVal = -1) or
- a negative learning rate will mean a descent form of the
- `optim_scheme` is being employed
-
- pre_wght: pre-synaptic weighting factor (Default: 1.)
-
- post_wght: post-synaptic weighting factor (Default: 1.)
-
- resist_scale: a fixed scaling factor to apply to synaptic transform
- (Default: 1.), i.e., yields: out = ((W * Rscale) * in) + b
-
- p_conn: probability of a connection existing (default: 1.); setting
- this to < 1. will result in a sparser synaptic structure
- """
-
- def __init__(self, name, shape, n_sub_models, stride_shape=(0,0), eta=0., weight_init=None, bias_init=None,
- w_mask=None, w_bound=1., is_nonnegative=False, w_decay=0., sign_value=1.,
- optim_type="sgd", pre_wght=1., post_wght=1., p_conn=1.,
- resist_scale=1., batch_size=1, **kwargs):
- super().__init__(name, shape, n_sub_models, stride_shape, w_mask, weight_init, bias_init, resist_scale,
- p_conn, batch_size=batch_size, **kwargs)
-
- self.n_sub_models = n_sub_models
- self.sub_stride = stride_shape
-
- self.shape = (shape[0] + (2 * stride_shape[0]),
- shape[1] + (2 * stride_shape[1]))
- self.sub_shape = (shape[0]//n_sub_models + (2 * stride_shape[0]),
- shape[1]//n_sub_models + (2* stride_shape[1]))
-
- ## synaptic plasticity properties and characteristics
- self.Rscale = resist_scale
- self.w_bound = w_bound
- self.w_decay = w_decay ## synaptic decay
- self.pre_wght = pre_wght
- self.post_wght = post_wght
- self.eta = eta
- self.is_nonnegative = is_nonnegative
- self.sign_value = sign_value
-
- ## optimization / adjustment properties (given learning dynamics above)
- self.opt = get_opt_step_fn(optim_type, eta=self.eta)
-
- # compartments (state of the cell, parameters, will be updated through stateless calls)
- self.preVals = jnp.zeros((self.batch_size, self.shape[0]))
- self.postVals = jnp.zeros((self.batch_size, self.shape[1]))
- self.pre = Compartment(self.preVals)
- self.post = Compartment(self.postVals)
- self.w_mask = w_mask
- self.dWeights = Compartment(jnp.zeros(self.shape))
- self.dBiases = Compartment(jnp.zeros(self.shape[1]))
-
- #key, subkey = random.split(self.key.value)
- self.opt_params = Compartment(get_opt_init_fn(optim_type)(
- [self.weights.value, self.biases.value]
- if bias_init else [self.weights.value]))
-
- @staticmethod
- def _compute_update(w_mask, w_bound, is_nonnegative, sign_value, w_decay, pre_wght,
- post_wght, pre, post, weights):
- ## calculate synaptic update values
- dW, db = _calc_update(
- pre, post, weights, w_mask, w_bound, is_nonnegative=is_nonnegative,
- signVal=sign_value, w_decay=w_decay, pre_wght=pre_wght,
- post_wght=post_wght)
-
- return dW * jnp.where(0 != jnp.abs(weights), 1, 0) , db
-
- @staticmethod
- def _evolve(w_mask, opt, w_bound, is_nonnegative, sign_value, w_decay, pre_wght,
- post_wght, bias_init, pre, post, weights, biases, opt_params):
- ## calculate synaptic update values
- dWeights, dBiases = HebbianPatchedSynapse._compute_update(
- w_mask, w_bound, is_nonnegative, sign_value, w_decay,
- pre_wght, post_wght, pre, post, weights
- )
- ## conduct a step of optimization - get newly evolved synaptic weight value matrix
- if bias_init != None:
- opt_params, [weights, biases] = opt(opt_params, [weights, biases], [dWeights, dBiases])
- else:
- # ignore db since no biases configured
- opt_params, [weights] = opt(opt_params, [weights], [dWeights])
- ## ensure synaptic efficacies adhere to constraints
- weights = _enforce_constraints(weights, w_mask, w_bound, is_nonnegative=is_nonnegative)
- return opt_params, weights, biases, dWeights, dBiases
-
- @resolver(_evolve)
- def evolve(self, opt_params, weights, biases, dWeights, dBiases):
- self.opt_params.set(opt_params)
- self.weights.set(weights)
- self.biases.set(biases)
- self.dWeights.set(dWeights)
- self.dBiases.set(dBiases)
-
- @staticmethod
- def _reset(batch_size, shape):
- preVals = jnp.zeros((batch_size, shape[0]))
- postVals = jnp.zeros((batch_size, shape[1]))
- return (
- preVals, # inputs
- postVals, # outputs
- preVals, # pre
- postVals, # post
- jnp.zeros(shape), # dW
- jnp.zeros(shape[1]), # db
- )
-
- @classmethod
- def help(cls): ## component help function
- properties = {
- "synapse_type": "HebbianSynapse - performs an adaptable synaptic "
- "transformation of inputs to produce output signals; "
- "synapses are adjusted via two-term/factor Hebbian adjustment"
- }
- compartment_props = {
- "inputs":
- {"inputs": "Takes in external input signal values",
- "pre": "Pre-synaptic statistic for Hebb rule (z_j)",
- "post": "Post-synaptic statistic for Hebb rule (z_i)"},
- "states":
- {"weights": "Synapse efficacy/strength parameter values",
- "biases": "Base-rate/bias parameter values",
- "key": "JAX PRNG key"},
- "analytics":
- {"dWeights": "Synaptic weight value adjustment matrix produced at time t",
- "dBiases": "Synaptic bias/base-rate value adjustment vector produced at time t"},
- "outputs":
- {"outputs": "Output of synaptic transformation"},
- }
- hyperparams = {
- "shape": "Overall shape of synaptic weight value matrix; number inputs x number outputs",
- "n_sub_models": "The number of submodels in each layer",
- "stride_shape": "Stride shape of overlapping synaptic weight value matrix",
- "batch_size": "Batch size dimension of this component",
- "weight_init": "Initialization conditions for synaptic weight (W) values",
- "bias_init": "Initialization conditions for bias/base-rate (b) values",
- "resist_scale": "Resistance level scaling factor (applied to output of transformation)",
- "p_conn": "Probability of a connection existing (otherwise, it is masked to zero)",
- "is_nonnegative": "Should synapses be constrained to be non-negative post-updates?",
- "sign_value": "Scalar `flipping` constant -- changes direction to Hebbian descent if < 0",
- "eta": "Global (fixed) learning rate",
- "pre_wght": "Pre-synaptic weighting coefficient (q_pre)",
- "post_wght": "Post-synaptic weighting coefficient (q_post)",
- "w_bound": "Soft synaptic bound applied to synapses post-update",
- "w_decay": "Synaptic decay term",
- "optim_type": "Choice of optimizer to adjust synaptic weights"
- }
- info = {cls.__name__: properties,
- "compartments": compartment_props,
- "dynamics": "outputs = [(W * Rscale) * inputs] + b ;"
- "dW_{ij}/dt = eta * [(z_j * q_pre) * (z_i * q_post)] - W_{ij} * w_decay",
- "hyperparameters": hyperparams}
- return info
-
- @resolver(_reset)
- def reset(self, inputs, outputs, pre, post, dWeights, dBiases):
- self.inputs.set(inputs)
- self.outputs.set(outputs)
- self.pre.set(pre)
- self.post.set(post)
- self.dWeights.set(dWeights)
- self.dBiases.set(dBiases)
-
- def __repr__(self):
- comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))]
- maxlen = max(len(c) for c in comps) + 5
- lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
- for c in comps:
- stats = tensorstats(getattr(self, c).value)
- if stats is not None:
- line = [f"{k}: {v}" for k, v in stats.items()]
- line = ", ".join(line)
- else:
- line = "None"
- lines += f" {f'({c})'.ljust(maxlen)}{line}\n"
- return lines
-
-if __name__ == '__main__':
- from ngcsimlib.context import Context
- with Context("Bar") as bar:
- Wab = HebbianPatchedSynapse("Wab", (9, 30), 3)
- print(Wab)
- plt.imshow(Wab.weights.value, cmap='gray')
- plt.show()
diff --git a/ngclearn/components/synapses/patched/patchedSynapse.py b/ngclearn/components/synapses/patched/patchedSynapse.py
deleted file mode 100644
index ef22e06ee..000000000
--- a/ngclearn/components/synapses/patched/patchedSynapse.py
+++ /dev/null
@@ -1,192 +0,0 @@
-import matplotlib.pyplot as plt
-from jax import random, numpy as jnp, jit
-from ngclearn import resolver, Component, Compartment
-from ngclearn.components.jaxComponent import JaxComponent
-from ngclearn.utils import tensorstats
-from ngclearn.utils.weight_distribution import initialize_params
-from ngcsimlib.logger import info
-import math
-
-
-"""
- π³πΎ π³π
- β¬οΈ β¬οΈ
- ππΎ = 3, π«πΎ π«π
- π
πΎ = π«πΎ / ππΎ π
π = π«π / ππΎ
-
- β―[π€πΎπΒΉ π π ]β―
- ZπΎ β―[ π π€πΎπΒ² π ]β― Zπ
- β―[ π π π€πΎπΒ³]β―
-
-"""
-
-
-def create_multi_patch_synapses(key, shape, n_sub_models, sub_stride, weight_init):
- sub_shape = (shape[0] // n_sub_models, shape[1] // n_sub_models)
- di, dj = sub_shape
- si, sj = sub_stride
-
- weight_shape = ((n_sub_models * di) + 2 * si, (n_sub_models * dj) + 2 * sj)
- weights = initialize_params(key[2], {"dist": "constant", "value": 0.}, weight_shape, use_numpy=True)
-
- for i in range(n_sub_models):
- start_i = i * di
- end_i = (i + 1) * di + 2 * si
- start_j = i * dj
- end_j = (i + 1) * dj + 2 * sj
-
- shape_ = (end_i - start_i, end_j - start_j) # (di + 2 * si, dj + 2 * sj)
-
- weights[start_i : end_i,
- start_j : end_j] = initialize_params(key[2],
- init_kernel=weight_init,
- shape=shape_,
- use_numpy=True)
- if si!=0:
- weights[:si,:] = 0.
- weights[-si:,:] = 0.
- if sj!=0:
- weights[:,:sj] = 0.
- weights[:, -sj:] = 0.
-
- return weights
-
-
-
-class PatchedSynapse(JaxComponent): ## base patched synaptic cable
- # Define Functions
- def __init__(self, name, shape, n_sub_models, stride_shape=(0,0), w_mask=None, weight_init=None, bias_init=None,
- resist_scale=1., p_conn=1., batch_size=1, **kwargs):
- super().__init__(name, **kwargs)
-
- self.Rscale = resist_scale
- self.batch_size = batch_size
- self.weight_init = weight_init
- self.bias_init = bias_init
-
- self.n_sub_models = n_sub_models
- self.sub_stride = stride_shape
-
- tmp_key, *subkeys = random.split(self.key.value, 4)
- if self.weight_init is None:
- info(self.name, "is using default weight initializer!")
- self.weight_init = {"dist": "fan_in_gaussian"}
-
- weights = create_multi_patch_synapses(key=subkeys, shape=shape, n_sub_models=self.n_sub_models, sub_stride=self.sub_stride,
- weight_init=self.weight_init)
-
- self.w_mask = jnp.where(weights!=0, 1, 0)
- self.sub_shape = (shape[0]//n_sub_models, shape[1]//n_sub_models)
-
- self.shape = weights.shape
- self.sub_shape = self.sub_shape[0]+(2*self.sub_stride[0]), self.sub_shape[1]+(2*self.sub_stride[1])
-
- if 0. < p_conn < 1.: ## only non-zero and <1 probs allowed
- mask = random.bernoulli(subkeys[1], p=p_conn, shape=self.shape)
- weights = weights * mask ## sparsify matrix
-
- ## Compartment setup
- preVals = jnp.zeros((self.batch_size, self.shape[0]))
- postVals = jnp.zeros((self.batch_size, self.shape[1]))
- self.inputs = Compartment(preVals)
- self.outputs = Compartment(postVals)
- self.weights = Compartment(weights)
-
- ## Set up (optional) bias values
- if self.bias_init is None:
- info(self.name, "is using default bias value of zero (no bias "
- "kernel provided)!")
- self.biases = Compartment(initialize_params(subkeys[2], bias_init,
- (1, self.shape[1]))
- if bias_init else 0.0)
- @staticmethod
- def _advance_state(Rscale, inputs, weights, biases):
- outputs = (jnp.matmul(inputs, weights) * Rscale) + biases
- return outputs
-
- @resolver(_advance_state)
- def advance_state(self, outputs):
- self.outputs.set(outputs)
-
- @staticmethod
- def _reset(batch_size, shape):
- preVals = jnp.zeros((batch_size, shape[0]))
- postVals = jnp.zeros((batch_size, shape[1]))
- inputs = preVals
- outputs = postVals
- return inputs, outputs
-
- @resolver(_reset)
- def reset(self, inputs, outputs):
- self.inputs.set(inputs)
- self.outputs.set(outputs)
-
- def save(self, directory, **kwargs):
- file_name = directory + "/" + self.name + ".npz"
- if self.bias_init != None:
- jnp.savez(file_name, weights=self.weights.value,
- biases=self.biases.value)
- else:
- jnp.savez(file_name, weights=self.weights.value)
-
- def load(self, directory, **kwargs):
- file_name = directory + "/" + self.name + ".npz"
- data = jnp.load(file_name)
- self.weights.set(data['weights'])
- if "biases" in data.keys():
- self.biases.set(data['biases'])
-
- @classmethod
- def help(cls): ## component help function
- properties = {
- "synapse_type": "PatchedSynapse - performs a synaptic transformation "
- "of inputs to produce output signals (e.g., a "
- "scaled linear multivariate transformation)"
- }
- compartment_props = {
- "inputs":
- {"inputs": "Takes in external input signal values"},
- "states":
- {"weights": "Synapse efficacy/strength parameter values",
- "biases": "Base-rate/bias parameter values",
- "key": "JAX PRNG key"},
- "outputs":
- {"outputs": "Output of synaptic transformation"},
- }
- hyperparams = {
- "shape": "Overall shape of synaptic weight value matrix; number inputs x number outputs",
- "n_sub_models": "The number of submodels in each layer",
- "stride_shape": "Stride shape of overlapping synaptic weight value matrix",
- "batch_size": "Batch size dimension of this component",
- "weight_init": "Initialization conditions for synaptic weight (W) values",
- "bias_init": "Initialization conditions for bias/base-rate (b) values",
- "resist_scale": "Resistance level scaling factor (Rscale); applied to output of transformation",
- "p_conn": "Probability of a connection existing (otherwise, it is masked to zero)"
- }
- info = {cls.__name__: properties,
- "compartments": compartment_props,
- "dynamics": "outputs = [W * inputs] * Rscale + b",
- "hyperparameters": hyperparams}
- return info
-
- def __repr__(self):
- comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))]
- maxlen = max(len(c) for c in comps) + 5
- lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
- for c in comps:
- stats = tensorstats(getattr(self, c).value)
- if stats is not None:
- line = [f"{k}: {v}" for k, v in stats.items()]
- line = ", ".join(line)
- else:
- line = "None"
- lines += f" {f'({c})'.ljust(maxlen)}{line}\n"
- return lines
-
-if __name__ == '__main__':
- from ngcsimlib.context import Context
- with Context("Bar") as bar:
- Wab = PatchedSynapse("Wab", (9, 30), 3)
- print(Wab)
- plt.imshow(Wab.weights.value, cmap='gray')
- plt.show()
diff --git a/ngclearn/components/synapses/patched/staticPatchedSynapse.py b/ngclearn/components/synapses/patched/staticPatchedSynapse.py
deleted file mode 100644
index 6cbcf988e..000000000
--- a/ngclearn/components/synapses/patched/staticPatchedSynapse.py
+++ /dev/null
@@ -1,4 +0,0 @@
-from .patchedSynapse import PatchedSynapse
-
-class StaticPatchedSynapse(PatchedSynapse):
- pass
\ No newline at end of file
diff --git a/ngclearn/utils/diffeq/ode_functions.py b/ngclearn/utils/diffeq/ode_functions.py
new file mode 100644
index 000000000..037ee9d1d
--- /dev/null
+++ b/ngclearn/utils/diffeq/ode_functions.py
@@ -0,0 +1,36 @@
+import jax.numpy as jnp
+from jax import jit
+from functools import partial
+
+'''
+(Ordinary) differential equation co-routines for use with the integration
+routines in `ode_solver.py`; a suggested initial condition for `linear_2D`
+is x0 = jnp.array([3, -1.5]).
+'''
+@partial(jit, static_argnums=(2,))  ## `params` (argnum 2) is static; `t` and `x` remain traced
+def linear_2D(t, x, params):
+ '''
+ :param x: 2D vector
+ type: jax array
+ shape:(2,)
+
+ :param t: Unused
+
+ :param params: Unused
+
+ :return: 2D vector: [
+ -0.1 * x[0] + 2.0 * x[1],
+ -2.0 * x[0] - 0.1 * x[1]
+ ]
+ type: jax array
+ shape:(2,)
+
+ ------------------------------------------
+ * suggested init value-
+ x0 = jnp.array([3, -1.5])
+ '''
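+    ## with A = [[-0.1, 2.], [-2., -0.1]] and row-vector states, dx/dt = x @ A^T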
+ coeff = jnp.array([[-0.1, 2], [-2, -0.1]]).T
+ dfx_ = jnp.matmul(x, coeff)
+
+ return dfx_
\ No newline at end of file
diff --git a/ngclearn/utils/diffeq/ode_solver.py b/ngclearn/utils/diffeq/ode_solver.py
new file mode 100644
index 000000000..565043eb3
--- /dev/null
+++ b/ngclearn/utils/diffeq/ode_solver.py
@@ -0,0 +1,292 @@
+import jax.numpy as jnp
+from jax import jit
+from functools import partial
+from jax.lax import scan as _scan
+
+
+
+@jit
+def _sum_combine(*args, **kwargs): ## fast co-routine for simple addition
+ sum = 0
+
+ for arg, val in zip(args, kwargs.values()):
+ sum = sum + val * arg
+ return sum
+
+@partial(jit, static_argnums=(3, 4))
+def _step_forward(t, x, dx_dt, dt, x_scale): ## internal step co-routine
+ _t = t + dt
+ _x = x * x_scale + dx_dt * dt
+ return _t, _x
+
+@partial(jit, static_argnums=(1, 2, 3, 4, ))
+def euler(carry, dfx, dt, params, x_scale=1.):
+ """
+ Iteratively integrates one step forward via the Euler method, i.e., a
+ first-order Runge-Kutta (RK-1) step.
+
+ Args:
+        carry: tuple `(t, x)` holding the current time `t` and the current
+            variable values `x` to advance/iteratively integrate
+
+ dfx: (ordinary) differential equation co-routine (as implemented in an
+ ngc-learn component)
+
+ dt: integration time step (also referred to as `h` in mathematics)
+
+ params: tuple containing configuration values/hyper-parameters for the
+ (ordinary) differential equation an ngc-learn component will provide
+
+ x_scale: dampening factor to scale `x` by (Default: 1)
+
+ Returns:
+        a pair `(new_carry, (new_carry, carry))`: the carry advanced to `t + dt`,
+        plus the per-step outputs accumulated by `jax.lax.scan`
+ """
+ t, x = carry
+
+ dx_dt = dfx(t, x, params)
+ _t, _x = _step_forward(t, x, dx_dt, dt, x_scale)
+
+ new_carry = (_t, _x)
+ return new_carry, (new_carry, carry)
+
+@partial(jit, static_argnums=(1, 2, 3, 4, ))
+def heun(carry, dfx, dt, params, x_scale=1.):
+ """
+ Iteratively integrates one step forward via Heun's method, i.e., a
+ second-order Runge-Kutta (RK-2) error-corrected step. This method utilizes
+ two (differential) function evaluations to estimate the solution at a given
+ point in time.
+ (Note: ngc-learn internally recognizes "rk2_heun" or "heun" for this routine)
+
+ | Reference:
+ | Ascher, Uri M., and Linda R. Petzold. Computer methods for ordinary
+ | differential equations and differential-algebraic equations. Society for
+ | Industrial and Applied Mathematics, 1998.
+
+ Args:
+        carry: tuple `(t, x)` holding the current time `t` and the current
+            variable values `x` to advance/iteratively integrate
+
+ dfx: (ordinary) differential equation co-routine (as implemented in an
+ ngc-learn component)
+
+ dt: integration time step (also referred to as `h` in mathematics)
+
+ params: tuple containing configuration values/hyper-parameters for the
+ (ordinary) differential equation an ngc-learn component will provide
+
+ x_scale: dampening factor to scale `x` by (Default: 1)
+
+ Returns:
+        a pair `(new_carry, (new_carry, carry))`: the carry advanced to `t + dt`,
+        plus the per-step outputs accumulated by `jax.lax.scan`
+ """
+ t, x = carry
+
+ dx_dt = dfx(t, x, params)
+ _t, _x = _step_forward(t, x, dx_dt, dt, x_scale)
+ _dx_dt = dfx(_t, _x, params)
+ summed_dx_dt = _sum_combine(dx_dt, _dx_dt, weight1=1, weight2=1)
+ _, _x = _step_forward(t, x, summed_dx_dt, dt * 0.5, x_scale)
+
+ new_carry = (_t, _x)
+ return new_carry, (new_carry, carry)
+
+@partial(jit, static_argnums=(1, 2, 3, 4, ))
+def rk2(carry, dfx, dt, params, x_scale=1.):
+ """
+ Iteratively integrates one step forward via the midpoint method, i.e., a
+ second-order Runge-Kutta (RK-2) step.
+ (Note: ngc-learn internally recognizes "rk2" or "midpoint" for this routine)
+
+ | Reference:
+ | Ascher, Uri M., and Linda R. Petzold. Computer methods for ordinary
+ | differential equations and differential-algebraic equations. Society for
+ | Industrial and Applied Mathematics, 1998.
+
+ Args:
+        carry: tuple `(t, x)` holding the current time `t` and the current
+            variable values `x` to advance/iteratively integrate
+
+ dfx: (ordinary) differential equation co-routine (as implemented in an
+ ngc-learn component)
+
+ dt: integration time step (also referred to as `h` in mathematics)
+
+ params: tuple containing configuration values/hyper-parameters for the
+ (ordinary) differential equation an ngc-learn component will provide
+
+ x_scale: dampening factor to scale `x` by (Default: 1)
+
+ Returns:
+        a pair `(new_carry, (new_carry, carry))`: the carry advanced to `t + dt`,
+        plus the per-step outputs accumulated by `jax.lax.scan`
+ """
+ t, x = carry
+
+ f_1 = dfx(t, x, params)
+ t1, x1 = _step_forward(t, x, f_1, dt * 0.5, x_scale)
+ f_2 = dfx(t1, x1, params)
+ _t, _x = _step_forward(t, x, f_2, dt, x_scale)
+
+ new_carry = (_t, _x)
+ return new_carry, (new_carry, carry)
+
+@partial(jit, static_argnums=(1, 2, 3, 4, ))
+def rk4(carry, dfx, dt, params, x_scale=1.):
+ """
+    Iteratively integrates one step forward via the classic fourth-order
+    Runge-Kutta (RK-4) method.
+    (Note: ngc-learn internally recognizes "rk4" for this routine)
+
+ | Reference:
+ | Ascher, Uri M., and Linda R. Petzold. Computer methods for ordinary
+ | differential equations and differential-algebraic equations. Society for
+ | Industrial and Applied Mathematics, 1998.
+
+ Args:
+        carry: tuple `(t, x)` holding the current time `t` and the current
+            variable values `x` to advance/iteratively integrate
+
+ dfx: (ordinary) differential equation co-routine (as implemented in an
+ ngc-learn component)
+
+ dt: integration time step (also referred to as `h` in mathematics)
+
+ params: tuple containing configuration values/hyper-parameters for the
+ (ordinary) differential equation an ngc-learn component will provide
+
+ x_scale: dampening factor to scale `x` by (Default: 1)
+
+ Returns:
+        a pair `(new_carry, (new_carry, carry))`: the carry advanced to `t + dt`,
+        plus the per-step outputs accumulated by `jax.lax.scan`
+ """
+
+ t, x = carry
+
+ f_1 = dfx(t, x, params)
+ t1, x1 = _step_forward(t, x, f_1, dt * 0.5, x_scale)
+
+ f_2 = dfx(t1, x1, params)
+ t2, x2 = _step_forward(t, x, f_2, dt * 0.5, x_scale)
+
+ f_3 = dfx(t2, x2, params)
+ t3, x3 = _step_forward(t, x, f_3, dt, x_scale)
+
+ f_4 = dfx(t3, x3, params)
+
+ _dx_dt = _sum_combine(f_1, f_2, f_3, f_4, w_f1=1, w_f2=2, w_f3=2, w_f4=1)
+ _t, _x = _step_forward(t, x, _dx_dt, dt / 6, x_scale)
+
+ new_carry = (_t, _x)
+ return new_carry, (new_carry, carry)
+
+@partial(jit, static_argnums=(1, 2, 3, 4,))
+def ralston(carry, dfx, dt, params, x_scale=1.):
+ """
+ Iteratively integrates one step forward via Ralston's method, i.e., a
+ second-order Runge-Kutta (RK-2) error-corrected step. This method utilizes
+ two (differential) function evaluations to estimate the solution at a given
+ point in time.
+ (Note: ngc-learn internally recognizes "rk2_ralston" or "ralston" for this
+ routine)
+
+ | Reference:
+ | Ralston, Anthony. "Runge-Kutta methods with minimum error bounds."
+ | Mathematics of computation 16.80 (1962): 431-437.
+
+ Args:
+        carry: tuple `(t, x)` holding the current time `t` and the current
+            variable values `x` to advance/iteratively integrate
+
+ dfx: (ordinary) differential equation co-routine (as implemented in an
+ ngc-learn component)
+
+ dt: integration time step (also referred to as `h` in mathematics)
+
+ params: tuple containing configuration values/hyper-parameters for the
+ (ordinary) differential equation an ngc-learn component will provide
+
+ x_scale: dampening factor to scale `x` by (Default: 1)
+
+ Returns:
+        a pair `(new_carry, (new_carry, carry))`: the carry advanced to `t + dt`,
+        plus the per-step outputs accumulated by `jax.lax.scan`
+ """
+
+ t, x = carry
+
+ dx_dt = dfx(t, x, params) ## k1
+ tm, xm = _step_forward(t, x, dx_dt, dt * 0.75, x_scale)
+ _dx_dt = dfx(tm, xm, params) ## k2
+ ## Note: new step is a weighted combination of k1 and k2
+ summed_dx_dt = _sum_combine(dx_dt, _dx_dt, weight1=(1./3.), weight2=(2./3.))
+ _t, _x = _step_forward(t, x, summed_dx_dt, dt, x_scale)
+
+ new_carry = (_t, _x)
+ return new_carry, (new_carry, carry)
+
+
+@partial(jit, static_argnums=(0, 3, 4, 5, 6, 7, 8))
+def solve_ode(method_name, t0, x0, T, dfx, dt, params=None, x_scale=1., sols_only=True):
+    """
+    Integrates the ordinary differential equation `dfx` forward for `T` steps of
+    size `dt`, starting from time `t0` and state `x0`, using the integration
+    routine named by `method_name` (one of "euler", "heun", "rk2", "rk4", or
+    "ralston").
+
+    Returns the stacked trajectory `(ts, xs)`; if `sols_only` is False, the final
+    carry and the trajectory of previous carries are also returned.
+    """
+    if method_name == 'euler':
+        method = euler
+    elif method_name == 'heun':
+        method = heun
+    elif method_name == 'rk2':
+        method = rk2
+    elif method_name == 'rk4':
+        method = rk4
+    elif method_name == 'ralston':
+        method = ralston
+    else:
+        raise ValueError("Unknown ODE integration method: {}".format(method_name))
+
+ def scanner(carry, _):
+ return method(carry, dfx, dt, params, x_scale)
+
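+    ## scan advances the carry (t, x) for T steps, stacking both the advanced and
+    ## the previous carries to yield the full (ts, xs) trajectory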
+ x_T, (xs_next, xs_carry) = _scan(scanner, init=(t0, x0), xs=jnp.arange(T))
+
+ if not sols_only:
+ return x_T, xs_next, xs_carry
+
+ return xs_next
+
+
+
+########################################################################################
+########################################################################################
+if __name__ == '__main__':
+ import matplotlib.pyplot as plt
+ from ode_functions import linear_2D
+
+ dfx = linear_2D
+ x0 = jnp.array([3, -1.5])
+
+ dt = 1e-2
+ t0 = 0.
+ T = 800
+
+ (t_final, x_final), (ts_sol, sol_euler), (ts_carr, xs_carr) = solve_ode('euler', t0, x0, T=T, dfx=dfx, dt=dt, params=None, sols_only=False)
+ (_, x_final), (_, sol_heun), (_, xs_carr) = solve_ode('heun', t0, x0, T=T, dfx=dfx, dt=dt, params=None, sols_only=False)
+ (_, x_final), (_, sol_rk2), (_, xs_carr) = solve_ode('rk2', t0, x0, T=T, dfx=dfx, dt=dt, params=None, sols_only=False)
+ (_, x_final), (_, sol_rk4), (_, xs_carr) = solve_ode('rk4', t0, x0, T=T, dfx=dfx, dt=dt, params=None, sols_only=False)
+ (_, x_final), (_, sol_ralston), (_, xs_carr) = solve_ode('ralston', t0, x0, T=T, dfx=dfx, dt=dt, params=None, sols_only=False)
+
+
+ plt.plot(ts_sol, sol_euler[:, 0], label='x0-Euler')
+ plt.plot(ts_sol, sol_heun[:, 0], label='x0-Heun')
+ plt.plot(ts_sol, sol_rk2[:, 0], label='x0-RK2')
+ plt.plot(ts_sol, sol_rk4[:, 0], label='x0-RK4')
+ plt.plot(ts_sol, sol_ralston[:, 0], label='x0-Ralston')
+
+ plt.plot(ts_sol, sol_euler[:, 1], label='x1-Euler')
+ plt.plot(ts_sol, sol_heun[:, 1], label='x1-Heun')
+ plt.plot(ts_sol, sol_rk2[:, 1], label='x1-RK2')
+ plt.plot(ts_sol, sol_rk4[:, 1], label='x1-RK4')
+ plt.plot(ts_sol, sol_ralston[:, 1], label='x1-Ralston')
+ plt.legend(loc='best')
+ plt.grid()
+ plt.show()
diff --git a/ngclearn/utils/diffeq/ode_utils.py b/ngclearn/utils/diffeq/ode_utils.py
index 460fd2f48..0a55a55da 100755
--- a/ngclearn/utils/diffeq/ode_utils.py
+++ b/ngclearn/utils/diffeq/ode_utils.py
@@ -44,8 +44,9 @@ def get_integrator_code(integrationType):
@jit
def _sum_combine(*args, **kwargs): ## fast co-routine for simple addition
sum = 0
+
for arg, val in zip(args, kwargs.values()):
- sum += arg * val
+ sum = sum + val * arg
return sum
@jit
@@ -54,6 +55,7 @@ def _step_forward(t, x, dx_dt, dt, x_scale): ## internal step co-routine
_x = x * x_scale + dx_dt * dt
return _t, _x
+@partial(jit, static_argnums=(2, ))
def step_euler(t, x, dfx, dt, params, x_scale=1.):
"""
Iteratively integrates one step forward via the Euler method, i.e., a
@@ -81,6 +83,7 @@ def step_euler(t, x, dfx, dt, params, x_scale=1.):
_t, _x = _step_forward(t, x, dx_dt, dt, x_scale)
return _t, _x
+@partial(jit, static_argnums=(2, ))
def step_heun(t, x, dfx, dt, params, x_scale=1.):
"""
Iteratively integrates one step forward via Heun's method, i.e., a
@@ -113,13 +116,15 @@ def step_heun(t, x, dfx, dt, params, x_scale=1.):
variable values iteratively integrated/advanced to next step (`t + dt`)
"""
dx_dt = dfx(t, x, params)
+
_t, _x = _step_forward(t, x, dx_dt, dt, x_scale)
_dx_dt = dfx(_t, _x, params)
summed_dx_dt = _sum_combine(dx_dt, _dx_dt, weight1=1, weight2=1)
+
_, _x = _step_forward(t, x, summed_dx_dt, dt * 0.5, x_scale)
return _t, _x
-
+@partial(jit, static_argnums=(2, ))
def step_rk2(t, x, dfx, dt, params, x_scale=1.):
"""
Iteratively integrates one step forward via the midpoint method, i.e., a
@@ -150,13 +155,13 @@ def step_rk2(t, x, dfx, dt, params, x_scale=1.):
variable values iteratively integrated/advanced to next step (`t + dt`)
"""
f_1 = dfx(t, x, params)
- t1, x1 = _step_forward(t, x, f_1, dt * 0.5, x_scale)
+ t1, x1 = _step_forward(t, x, f_1, dt * 0.5, x_scale)
f_2 = dfx(t1, x1, params)
_t, _x = _step_forward(t, x, f_2, dt, x_scale)
return _t, _x
-
+@partial(jit, static_argnums=(2, ))
def step_rk4(t, x, dfx, dt, params, x_scale=1.):
"""
Iteratively integrates one step forward via the midpoint method, i.e., a
@@ -201,8 +206,7 @@ def step_rk4(t, x, dfx, dt, params, x_scale=1.):
_t, _x = _step_forward(t, x, _dx_dt, dt / 6, x_scale)
return _t, _x
-
-
+@partial(jit, static_argnums=(2, ))
def step_ralston(t, x, dfx, dt, params, x_scale=1.):
"""
Iteratively integrates one step forward via Ralston's method, i.e., a
diff --git a/ngclearn/utils/model_utils.py b/ngclearn/utils/model_utils.py
index 59bc8a32a..704ec7e0f 100755
--- a/ngclearn/utils/model_utils.py
+++ b/ngclearn/utils/model_utils.py
@@ -83,6 +83,10 @@ def create_function(fun_name, args=None):
if fun_name == "tanh":
fx = tanh
dfx = d_tanh
+ elif fun_name == "sine":
+ fx = sine
+ dfx = d_sine
+ omega_0 = args
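+        ## for the sine activation, `args` supplies the frequency term omega_0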
elif fun_name == "sigmoid":
fx = sigmoid
dfx = d_sigmoid
@@ -266,6 +270,34 @@ def d_relu(x):
"""
return (x >= 0.).astype(jnp.float32)
+@jit
+def sine(x, omega_0=30):
+ """
+    Sine activation: f(x) = sin(omega_0 * x).
+
+    Args:
+        x: input (tensor) value
+
+        omega_0: frequency/scaling factor applied to the input (Default: 30)
+
+ Returns:
+ output (tensor) value
+ """
+ return jnp.sin(omega_0 * x)
+
+@jit
+def d_sine(x, omega_0=30):
+ """
+    Derivative of the sine activation: f'(x) = omega_0 * cos(omega_0 * x).
+
+    Args:
+        x: input (tensor) value
+
+        omega_0: frequency/scaling factor applied to the input (Default: 30)
+
+ Returns:
+ output (tensor) value
+ """
+ return omega_0 * jnp.cos(omega_0 * x)
+
+
@jit
def tanh(x):
"""