From 05ea912eb14ec0e29c5ad9f1fb2b57815df5c384 Mon Sep 17 00:00:00 2001 From: ago109 Date: Wed, 24 Jul 2024 13:10:54 -0400 Subject: [PATCH] modded bernoulli-cell to include max-frequency constraint --- .../input_encoders/bernoulliCell.py | 47 ++++++++++++++++--- 1 file changed, 41 insertions(+), 6 deletions(-) diff --git a/ngclearn/components/input_encoders/bernoulliCell.py b/ngclearn/components/input_encoders/bernoulliCell.py index da090bfb..61e74031 100755 --- a/ngclearn/components/input_encoders/bernoulliCell.py +++ b/ngclearn/components/input_encoders/bernoulliCell.py @@ -2,6 +2,7 @@ from ngclearn.components.jaxComponent import JaxComponent from jax import numpy as jnp, random, jit from ngclearn.utils import tensorstats +from functools import partial @jit def _update_times(t, s, tols): @@ -37,9 +38,33 @@ def _sample_bernoulli(dkey, data): s_t = random.bernoulli(dkey, p=data).astype(jnp.float32) return s_t +@partial(jit, static_argnums=[3]) +def _sample_constrained_bernoulli(dkey, data, dt, fmax=63.75): + """ + Samples a Bernoulli spike train on-the-fly that is constrained to emit + at a particular rate over a time window. + + Args: + dkey: JAX key to drive stochasticity/noise + + data: sensory data (vector/matrix) + + dt: integration time constant + + fmax: maximum frequency (Hz) + + Returns: + binary spikes + """ + pspike = data * (dt/1000.) * fmax + eps = random.uniform(dkey, data.shape, minval=0., maxval=1., dtype=jnp.float32) + s_t = (eps < pspike).astype(jnp.float32) + return s_t + class BernoulliCell(JaxComponent): """ - A Bernoulli cell that produces Bernoulli-distributed spikes on-the-fly. + A Bernoulli cell that produces variations of Bernoulli-distributed spikes + on-the-fly (including constrained-rate trains). | --- Cell Input Compartments: --- | inputs - input (takes in external signals) @@ -53,12 +78,17 @@ class BernoulliCell(JaxComponent): name: the string name of this cell n_units: number of cellular entities (neural population size) + + max_freq: maximum frequency (in Hertz) of this Bernoulli spike train (must be > 0.) """ # Define Functions - def __init__(self, name, n_units, batch_size=1, **kwargs): + def __init__(self, name, n_units, max_freq=63.75, batch_size=1, **kwargs): super().__init__(name, **kwargs) + ## Constrained Bernoulli meta-parameters + self.max_freq = max_freq ## maximum frequency (in Hertz/Hz) + ## Layer Size Setup self.batch_size = batch_size self.n_units = n_units @@ -70,11 +100,16 @@ def __init__(self, name, n_units, batch_size=1, **kwargs): self.tols = Compartment(restVals, display_name="Time-of-Last-Spike", units="ms") # time of last spike @staticmethod - def _advance_state(t, key, inputs, tols): + def _advance_state(t, dt, max_freq, key, inputs, tols): key, *subkeys = random.split(key, 2) - outputs = _sample_bernoulli(subkeys[0], data=inputs) - timeOfLastSpike = _update_times(t, outputs, tols) - return outputs, timeOfLastSpike, key + if max_freq > 0.: + outputs = _sample_constrained_bernoulli( ## sample Bernoulli w/ target rate + subkeys[0], data=inputs, dt=dt, fmax=max_freq + ) + else: + outputs = _sample_bernoulli(subkeys[0], data=inputs) + tols = _update_times(t, outputs, tols) + return outputs, tols, key @resolver(_advance_state) def advance_state(self, outputs, tols, key):