Skip to content

Commit

Permalink
modded bernoulli-cell to include max-frequency constraint
Browse files Browse the repository at this point in the history
  • Loading branch information
ago109 committed Jul 24, 2024
1 parent 42167cb commit 05ea912
Showing 1 changed file with 41 additions and 6 deletions.
47 changes: 41 additions & 6 deletions ngclearn/components/input_encoders/bernoulliCell.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from ngclearn.components.jaxComponent import JaxComponent
from jax import numpy as jnp, random, jit
from ngclearn.utils import tensorstats
from functools import partial

@jit
def _update_times(t, s, tols):
Expand Down Expand Up @@ -37,9 +38,33 @@ def _sample_bernoulli(dkey, data):
s_t = random.bernoulli(dkey, p=data).astype(jnp.float32)
return s_t

@partial(jit, static_argnums=[3])
def _sample_constrained_bernoulli(dkey, data, dt, fmax=63.75):
"""
Samples a Bernoulli spike train on-the-fly that is constrained to emit
at a particular rate over a time window.
Args:
dkey: JAX key to drive stochasticity/noise
data: sensory data (vector/matrix)
dt: integration time constant
fmax: maximum frequency (Hz)
Returns:
binary spikes
"""
pspike = data * (dt/1000.) * fmax
eps = random.uniform(dkey, data.shape, minval=0., maxval=1., dtype=jnp.float32)
s_t = (eps < pspike).astype(jnp.float32)
return s_t

class BernoulliCell(JaxComponent):
"""
A Bernoulli cell that produces Bernoulli-distributed spikes on-the-fly.
A Bernoulli cell that produces variations of Bernoulli-distributed spikes
on-the-fly (including constrained-rate trains).
| --- Cell Input Compartments: ---
| inputs - input (takes in external signals)
Expand All @@ -53,12 +78,17 @@ class BernoulliCell(JaxComponent):
name: the string name of this cell
n_units: number of cellular entities (neural population size)
max_freq: maximum frequency (in Hertz) of this Bernoulli spike train (must be > 0.)
"""

# Define Functions
def __init__(self, name, n_units, batch_size=1, **kwargs):
def __init__(self, name, n_units, max_freq=63.75, batch_size=1, **kwargs):
super().__init__(name, **kwargs)

## Constrained Bernoulli meta-parameters
self.max_freq = max_freq ## maximum frequency (in Hertz/Hz)

## Layer Size Setup
self.batch_size = batch_size
self.n_units = n_units
Expand All @@ -70,11 +100,16 @@ def __init__(self, name, n_units, batch_size=1, **kwargs):
self.tols = Compartment(restVals, display_name="Time-of-Last-Spike", units="ms") # time of last spike

@staticmethod
def _advance_state(t, key, inputs, tols):
def _advance_state(t, dt, max_freq, key, inputs, tols):
key, *subkeys = random.split(key, 2)
outputs = _sample_bernoulli(subkeys[0], data=inputs)
timeOfLastSpike = _update_times(t, outputs, tols)
return outputs, timeOfLastSpike, key
if max_freq > 0.:
outputs = _sample_constrained_bernoulli( ## sample Bernoulli w/ target rate
subkeys[0], data=inputs, dt=dt, fmax=max_freq
)
else:
outputs = _sample_bernoulli(subkeys[0], data=inputs)
tols = _update_times(t, outputs, tols)
return outputs, tols, key

@resolver(_advance_state)
def advance_state(self, outputs, tols, key):
Expand Down

0 comments on commit 05ea912

Please sign in to comment.